From 784485f486532a8993ed352aea96646c2c5d87cc Mon Sep 17 00:00:00 2001 From: l00500167 Date: Mon, 25 Jul 2022 17:15:36 +0800 Subject: [PATCH 1/4] add st test Signed-off-by: l00500167 --- .../mindsponge/test_megafold/config/data.yaml | 76 ++ .../test_megafold/config/model.yaml | 218 +++++ .../mindsponge/test_megafold/data/__init__.py | 16 + .../mindsponge/test_megafold/data/dataset.py | 190 ++++ .../mindsponge/test_megafold/data/hhsearch.py | 85 ++ .../mindsponge/test_megafold/data/kalign.py | 97 ++ .../test_megafold/data/msa_query.py | 83 ++ .../test_megafold/data/msa_search.sh | 62 ++ .../test_megafold/data/preprocess.py | 523 ++++++++++ .../test_megafold/data/protein_feature.py | 132 +++ .../test_megafold/data/templates.py | 920 ++++++++++++++++++ .../st/mindsponge/test_megafold/data/utils.py | 40 + .../test_megafold/model/__init__.py | 16 + .../st/mindsponge/test_megafold/model/fold.py | 304 ++++++ .../test_megafold/module/evoformer.py | 131 +++ .../test_megafold/module/fold_wrapcell.py | 160 +++ .../mindsponge/test_megafold/module/head.py | 166 ++++ .../test_megafold/module/loss_module.py | 296 ++++++ .../test_megafold/module/structure.py | 261 +++++ .../module/template_embedding.py | 262 +++++ .../test_megafold/origin_length/T1070-D2.pkl | 1 + .../processed_feature/T1070-D2.pkl | 1 + .../mindsponge/test_megafold/test_megafold.py | 241 +++++ ..._atom_positions_ascend_mixed_precision.npy | Bin 0 -> 22550 bytes .../final_atom_positions_gpu_fp32.npy | Bin 0 -> 44972 bytes 25 files changed, 4281 insertions(+) create mode 100644 tests/st/mindsponge/test_megafold/config/data.yaml create mode 100644 tests/st/mindsponge/test_megafold/config/model.yaml create mode 100644 tests/st/mindsponge/test_megafold/data/__init__.py create mode 100644 tests/st/mindsponge/test_megafold/data/dataset.py create mode 100644 tests/st/mindsponge/test_megafold/data/hhsearch.py create mode 100644 tests/st/mindsponge/test_megafold/data/kalign.py create mode 100644 tests/st/mindsponge/test_megafold/data/msa_query.py create mode 100644 tests/st/mindsponge/test_megafold/data/msa_search.sh create mode 100644 tests/st/mindsponge/test_megafold/data/preprocess.py create mode 100644 tests/st/mindsponge/test_megafold/data/protein_feature.py create mode 100644 tests/st/mindsponge/test_megafold/data/templates.py create mode 100644 tests/st/mindsponge/test_megafold/data/utils.py create mode 100644 tests/st/mindsponge/test_megafold/model/__init__.py create mode 100644 tests/st/mindsponge/test_megafold/model/fold.py create mode 100644 tests/st/mindsponge/test_megafold/module/evoformer.py create mode 100644 tests/st/mindsponge/test_megafold/module/fold_wrapcell.py create mode 100644 tests/st/mindsponge/test_megafold/module/head.py create mode 100644 tests/st/mindsponge/test_megafold/module/loss_module.py create mode 100644 tests/st/mindsponge/test_megafold/module/structure.py create mode 100644 tests/st/mindsponge/test_megafold/module/template_embedding.py create mode 100644 tests/st/mindsponge/test_megafold/origin_length/T1070-D2.pkl create mode 100644 tests/st/mindsponge/test_megafold/processed_feature/T1070-D2.pkl create mode 100644 tests/st/mindsponge/test_megafold/test_megafold.py create mode 100644 tests/st/mindsponge/test_megafold/true_label/final_atom_positions_ascend_mixed_precision.npy create mode 100644 tests/st/mindsponge/test_megafold/true_label/final_atom_positions_gpu_fp32.npy diff --git a/tests/st/mindsponge/test_megafold/config/data.yaml b/tests/st/mindsponge/test_megafold/config/data.yaml new file mode 100644 
index 000000000..9b8e66aac --- /dev/null +++ b/tests/st/mindsponge/test_megafold/config/data.yaml @@ -0,0 +1,76 @@ +block_deletion: + msa_fraction_per_block: 0.3 + num_blocks: 5 + randomize_num_blocks: true +common: + random_recycle: false + distillation: false + replace_proportion: 0.0 + masked_msa: + use_masked_msa: True + profile_prob: 0.1 + same_prob: 0.1 + uniform_prob: 0.1 + max_extra_msa: 1024 + msa_cluster_features: true + num_recycle: 4 + reduce_msa_clusters_by_max_templates: true + resample_msa_in_recycling: true + use_templates: true + template_features: + - template_all_atom_positions + - template_sum_probs + - template_aatype + - template_all_atom_masks + - template_domain_names + unsupervised_features: + - aatype + - residue_index + - sequence + - msa + - domain_name + - num_alignments + - seq_length + - between_segment_residues + - deletion_matrix + - template_all_atom_positions + - template_sum_probs + - template_aatype + - template_all_atom_masks + - template_domain_names + supervised_features: + - all_atom_positions + - all_atom_mask + - atom14_atom_exists + - atom14_gt_exists + - atom14_gt_positions + - residx_atom14_to_atom37 + - residx_atom37_to_atom14 + - atom37_atom_exists + - atom14_alt_gt_positions + - atom14_alt_gt_exists + - atom14_atom_is_ambiguous + - rigidgroups_gt_frames + - rigidgroups_gt_exists + - rigidgroups_group_exists + - rigidgroups_group_is_ambiguous + - rigidgroups_alt_gt_frames + - backbone_affine_tensor + - torsion_angles_sin_cos + - alt_torsion_angles_sin_cos + - torsion_angles_mask + - pseudo_beta + - pseudo_beta_mask + - chi_mask + - backbone_affine_mask + + +eval: + crop_size: 256 + fixed_size: true + masked_msa_replace_fraction: 0.15 + max_msa_clusters: 128 + max_templates: 4 + num_ensemble: 1 + subsample_templates: true + keep_extra: True \ No newline at end of file diff --git a/tests/st/mindsponge/test_megafold/config/model.yaml b/tests/st/mindsponge/test_megafold/config/model.yaml new file mode 100644 index 000000000..d8a5f3e55 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/config/model.yaml @@ -0,0 +1,218 @@ +is_training: False +msa_channel: 256 +pair_channel: 128 +extra_msa_channel: 64 +max_relative_feature: 32 +recycle_features: True +recycle_pos: True +seq_channel: 384 +prev_pos: + min_bin: 3.25 + max_bin: 20.75 + num_bins: 15 +common: + target_feat_dim: 22 + msa_feat_dim: 49 + dgram_dim: 15 + pair_in_dim: 65 + msa_first_row_dim: 256 + prev_pair_dim: 128 + extra_msa_dim: 25 + template_feat_dim: 57 +template: + enabled: True + embed_torsion_angles: True + use_template_unit_vector: True + attention: + gating: False + key_dim: 64 + num_head: 4 + value_dim: 64 + dgram_features: + min_bin: 3.25 + max_bin: 50.75 + num_bins: 39 + template_pair_stack: + num_block: 2 + triangle_attention_starting_node: + dropout_rate: 0.25 + gating: True + key_dim: 64 + num_head: 4 + orientation: 'per_row' + shared_dropout: True + value_dim: 64 + triangle_attention_ending_node: + dropout_rate: 0.25 + gating: True + key_dim: 64 + num_head: 4 + orientation: 'per_column' + shared_dropout: True + value_dim: 64 + triangle_multiplication_outgoing: + dropout_rate: 0.25 + equation: 'ikc,jkc->ijc' + num_intermediate_channel: 64 + orientation: 'per_row' + shared_dropout: True + triangle_multiplication_incoming: + dropout_rate: 0.25 + equation: 'kjc,kic->ijc' + num_intermediate_channel: 64 + orientation: 'per_row' + shared_dropout: True + pair_transition: + dropout_rate: 0.0 + num_intermediate_factor: 2 + orientation: 'per_row' + shared_dropout: True +evoformer: +
msa_stack_num: 1 + extra_msa_stack_num: 1 + msa_row_attention_with_pair_bias: + dropout_rate: 0.15 # 0.15 + gating: True + num_head: 8 + orientation: 'per_row' + shared_dropout: True + msa_column_attention: + dropout_rate: 0.0 + gating: True + num_head: 8 + orientation: 'per_column' + shared_dropout: True + msa_transition: + dropout_rate: 0.0 + num_intermediate_factor: 4 + orientation: 'per_row' + shared_dropout: True + outer_product_mean: + chunk_size: 128 + dropout_rate: 0.0 + num_outer_channel: 32 + orientation: 'per_row' + shared_dropout: True + triangle_attention_starting_node: + dropout_rate: 0.25 # 0.25 + gating: True + num_head: 4 + orientation: 'per_row' + shared_dropout: True + triangle_attention_ending_node: + dropout_rate: 0.25 # 0.25 + gating: True + num_head: 4 + orientation: 'per_column' + shared_dropout: True + triangle_multiplication_outgoing: + dropout_rate: 0.25 # 0.25 + equation: 'ikc,jkc->ijc' + num_intermediate_channel: 128 + orientation: 'per_row' + shared_dropout: True + triangle_multiplication_incoming: + dropout_rate: 0.25 # 0.25 + equation: 'kjc,kic->ijc' + num_intermediate_channel: 128 + orientation: 'per_row' + shared_dropout: True + pair_transition: + dropout_rate: 0.0 + num_intermediate_factor: 4 + orientation: 'per_row' + shared_dropout: True +structure_module: + num_layer: 8 + fape: + clamp_distance: 10.0 + clamp_type: 'relu' + loss_unit_distance: 10.0 + angle_norm_weight: 0.01 + chi_weight: 0.5 + clash_overlap_tolerance: 1.5 + compute_in_graph_metrics: True + dropout: 0.1 + num_channel: 384 + num_head: 12 + num_layer_in_transition: 3 + num_point_qk: 4 + num_point_v: 8 + num_scalar_qk: 16 + num_scalar_v: 16 + position_scale: 10.0 + sidechain: + atom_clamp_distance: 10.0 + num_channel: 128 + num_residual_block: 2 + weight_frac: 0.5 + length_scale: 10.0
+ structural_violation_loss_weight: 1.0 + violation_tolerance_factor: 12.0 + weight: 1.0 +slice: + seq_256: + template_embedding: 0 + template_pair_stack: + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + extra_msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_global_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 + msa_stack: + msa_transition: 0 + msa_row_attention_with_pair_bias: 0 + msa_column_attention: 0 + outer_product_mean: 0 + triangle_attention_starting_node: 0 + triangle_attention_ending_node: 0 + pair_transition: 0 +heads: + resolution: 1 + predicted_lddt: + filter_by_resolution: True + max_resolution: 3.0 + min_resolution: 0.1 + num_bins: 50 + num_channels: 128 + weight: 0.01 + distogram: + first_break: 2.3125 + last_break: 21.6875 + num_bins: 64 + weight: 0.3 + masked_msa: + num_output: 23 + weight: 2.0 + predicted_aligned_error: + max_error_bin: 31.0 + num_bins: 64 + num_channels: 128 + filter_by_resolution: True + min_resolution: 0.1 + max_resolution: 3.0 + weight: 0.0 + experimentally_resolved: + filter_by_resolution: True + max_resolution: 3.0 + min_resolution: 0.1 + weight: 0.01 + structure_module: + fape: + clamp_distance: 10.0 + loss_unit_distance: 10.0 + angle_norm_weight: 0.01 + chi_weight: 0.5 + clash_overlap_tolerance: 1.5 + sidechain: + atom_clamp_distance: 10.0 + weight_frac: 0.5 + length_scale: 10.0 + structural_violation_loss_weight: 1.0 + violation_tolerance_factor: 12.0 \ No newline at end of file diff --git a/tests/st/mindsponge/test_megafold/data/__init__.py b/tests/st/mindsponge/test_megafold/data/__init__.py new file mode 100644 index 000000000..4c2835f9e --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +'''init''' +from .preprocess import Feature diff --git a/tests/st/mindsponge/test_megafold/data/dataset.py b/tests/st/mindsponge/test_megafold/data/dataset.py new file mode 100644 index 000000000..32861e105 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/dataset.py @@ -0,0 +1,190 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""train dataset""" +import datetime +import os +import pickle +import time +import numpy as np +from mindspore import dataset as ds +from mindspore.communication import get_rank + +from mindsponge.common.residue_constants import make_atom14_dists_bounds +from mindsponge.common.protein import from_pdb_string +from mindsponge.common.utils import make_atom14_positions +from mindsponge.data.data_transform import pseudo_beta_fn, atom37_to_frames, atom37_to_torsion_angles +from preprocess import Feature + + +def create_dataset(train_data_dir, raw_feature_dir, names, data_cfg, center_name_path, shuffle=False, + num_parallel_worker=4, + is_parallel=False, mixed_precision=False): + """create train dataset""" + column_name = ["target_feat", "msa_feat", "msa_mask", "seq_mask_batch", "aatype_batch", + "template_aatype", "template_all_atom_masks", + "template_all_atom_positions", "template_mask", + "template_pseudo_beta_mask", "template_pseudo_beta", "extra_msa", "extra_has_deletion", + "extra_deletion_value", "extra_msa_mask", "residx_atom37_to_atom14", + "atom37_atom_exists_batch", "residue_index_batch", "prev_pos", + "prev_msa_first_row", "prev_pair", "pseudo_beta_gt", + "pseudo_beta_mask_gt", "all_atom_mask_gt", + "true_msa", "bert_mask", "residue_index", "seq_mask", + "atom37_atom_exists", "aatype", "residx_atom14_to_atom37", + "atom14_atom_exists", "backbone_affine_tensor", "backbone_affine_mask", + "atom14_gt_positions", "atom14_alt_gt_positions", + "atom14_atom_is_ambiguous", "atom14_gt_exists", "atom14_alt_gt_exists", + "all_atom_positions", "rigidgroups_gt_frames", "rigidgroups_gt_exists", + "rigidgroups_alt_gt_frames", "torsion_angles_sin_cos_gt", "chi_mask", "atomtype_radius", + "restype_atom14_bond_lower_bound", "restype_atom14_bond_upper_bound", "use_clamped_fape", + "filter_by_solution", "prot_name_index"] + + dataset_generator = DatasetGenerator(train_data_dir, raw_feature_dir, names, data_cfg, center_name_path, + mixed_precision) + ds.config.set_prefetch_size(1) + + if is_parallel: + rank_id = get_rank() % 8 + rank_size = 8 + train_dataset = ds.GeneratorDataset(source=dataset_generator, column_names=column_name, + num_parallel_workers=num_parallel_worker, shuffle=shuffle, + num_shards=rank_size, + shard_id=rank_id, max_rowsize=16) + else: + train_dataset = ds.GeneratorDataset(source=dataset_generator, column_names=column_name, + num_parallel_workers=num_parallel_worker, shuffle=shuffle, max_rowsize=16) + return train_dataset + + +class DatasetGenerator: + """dataset generator""" + def __init__(self, train_data_dir, raw_feature_dir, names, data_cfg, resolution_data, mixed_precision): + self.t1 = time.time() + print("start dataset init: ", str(datetime.datetime.now())) + self.data_cfg = data_cfg + self.num_residues = data_cfg.eval.crop_size + self.train_data_dir = train_data_dir + self.raw_feature_dir = raw_feature_dir + self.names = [name.replace("\n", "") for name in names] + self.mixed_precision = mixed_precision + + self.resolution_info = resolution_data + print("end dataset init: ", time.time() - self.t1) + + def __getitem__(self, index): + prot_name = self.names[index] + prot_name_index = np.asarray([index]).astype(np.int32) + arrays, prev_pos, prev_msa_first_row, prev_pair, label_arrays = self._get_train_data(prot_name) + atomtype_radius = np.array( + [1.55, 1.7, 1.7, 1.7, 1.52, 1.7, 1.7, 1.7, 1.52, 1.52, 1.8, 1.7, 1.7, 1.7, 1.55, 1.55, + 1.52, 1.52, 1.8, 1.7, 1.7, 1.7, 1.7, 1.55, 1.55, 1.55, 1.52, 
1.52, 1.7, 1.55, 1.55, + 1.52, 1.7, 1.7, 1.7, 1.55, 1.52]) + restype_atom14_bond_lower_bound, restype_atom14_bond_upper_bound, _ = \ + make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=12.0) + use_clamped_fape = np.random.binomial(1, 0.9, size=1) + filter_by_solution = np.array(1.0) + extra_feats = [atomtype_radius, restype_atom14_bond_lower_bound, + restype_atom14_bond_upper_bound, use_clamped_fape, filter_by_solution, prot_name_index] + dtype = np.float32 + if self.mixed_precision: + dtype = np.float16 + extra_feats = [array.astype(dtype) for array in extra_feats] + all_feats = arrays + [prev_pos, prev_msa_first_row, prev_pair] + label_arrays + extra_feats + + return tuple(all_feats) + + def __len__(self): + return len(self.names) + + def _get_solution_flag(self, prot_name): + """get resolution data""" + prot_new_name = prot_name.rsplit('_', 1)[0] + if prot_new_name not in self.resolution_info: + return np.array(1.0).astype(np.float32) + resolution = float(self.resolution_info[prot_new_name]['resolution']) + nmr = self.resolution_info[prot_new_name]['method'] + if resolution < 3 and nmr != 'NMR': + return np.array(1.0).astype(np.float32) + return np.array(0.0).astype(np.float32) + + def _get_train_labels(self, prot_pdb): + """get train labels""" + aatype = prot_pdb.aatype + seq_len = len(aatype) + atom37_positions = prot_pdb.atom_positions.astype(np.float32) + atom37_mask = prot_pdb.atom_mask.astype(np.float32) + + # get ground truth of atom14 + label_features = {'aatype': aatype, + 'all_atom_positions': atom37_positions, + 'all_atom_mask': atom37_mask} + + atom14_features = make_atom14_positions(aatype, atom37_mask, atom37_positions) + atom14_keys = ["atom14_atom_exists", "atom14_gt_exists", "atom14_gt_positions", "residx_atom14_to_atom37", + "residx_atom37_to_atom14", "atom37_atom_exists", "atom14_alt_gt_positions", + "atom14_alt_gt_exists", "atom14_atom_is_ambiguous"] + for index, array in enumerate(atom14_features): + label_features[atom14_keys[index]] = array + + # get ground truth of rigid groups + rigidgroups_label_feature = atom37_to_frames(aatype, atom37_positions, atom37_mask, is_affine=True) + label_features.update(rigidgroups_label_feature) + + # get ground truth of angle + angle_label_feature = atom37_to_torsion_angles(aatype.reshape((1, -1)), + atom37_positions.reshape((1, seq_len, 37, 3)), + atom37_mask.reshape((1, seq_len, 37)), True) + label_features.update(angle_label_feature) + + # get pseudo_beta, pseudo_beta_mask + pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(aatype, atom37_positions, atom37_mask) + label_features["pseudo_beta"] = pseudo_beta + label_features["pseudo_beta_mask"] = pseudo_beta_mask + label_features["chi_mask"] = label_features.get("torsion_angles_mask")[:, 3:] + label_features['torsion_angles_sin_cos'] = label_features.get('torsion_angles_sin_cos')[:, 3:, :] + label_features['backbone_affine_mask'] = pseudo_beta_mask + label_features.pop("aatype") + + return label_features + + def _get_train_data(self, prot_name): + """get train data""" + pdb_path = os.path.join(self.train_data_dir, prot_name + '.pdb') + with open(pdb_path, 'r') as f: + prot_pdb = from_pdb_string(f.read()) + with open(os.path.join(self.raw_feature_dir, prot_name + '.pkl'), "rb") as f: + raw_feature = pickle.load(f) + label_features = self._get_train_labels(prot_pdb) + seed = global_seed() + raw_feature.update(label_features) + processed_feature = Feature(self.data_cfg, raw_feature, is_training=True) + processed_feat =
processed_feature.pipeline(self.data_cfg, self.mixed_precision, seed=seed) + return processed_feat + + +class SeedMaker: + """Return unique seeds.""" + + def __init__(self, initial_seed=0): + self.next_seed = initial_seed + + def __call__(self): + i = self.next_seed + self.next_seed += 1 + return i + + +global_seed = SeedMaker() diff --git a/tests/st/mindsponge/test_megafold/data/hhsearch.py b/tests/st/mindsponge/test_megafold/data/hhsearch.py new file mode 100644 index 000000000..5127a7610 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/hhsearch.py @@ -0,0 +1,85 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +HHsearch tools. +""" + +import glob +import os +import stat +import subprocess + +from absl import logging +from mindsponge.data.utils import timing +from data.utils import tmpdir_manager + + +class HHSearch: + """Python wrapper of the HHsearch binary. + Cited from https://github.com/deepmind/alphafold. + """ + + def __init__(self, + binary_path, + databases, + maxseq=1_000_000): + """Initializes the Python HHsearch wrapper. + + Args: + binary_path: The path to the HHsearch executable. + databases: A sequence of HHsearch database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + + Raises: + RuntimeError: If HHsearch binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + self.maxseq = maxseq + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + raise ValueError(f'Could not find HHsearch database {database_path}') + + def query(self, a3m): + """Queries the database using HHsearch using a given a3m.""" + with tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + input_path = os.path.join(query_tmp_dir, 'query.a3m') + hhr_path = os.path.join(query_tmp_dir, 'output.hhr') + with os.fdopen(os.open(input_path, os.O_RDWR|os.O_CREAT, stat.S_IRWXU), 'w') as f: + f.write(a3m) + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [self.binary_path, '-i', input_path, '-o', hhr_path, '-maxseq', str(self.maxseq), '-cpu', '8',] + \ + db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with timing('HHsearch query'): + stdout, stderr = process.communicate() + retcode = process.wait() + if retcode: + # Stderr is truncated to prevent proto size errors in Beam. 
+ raise RuntimeError('HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( + stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) + with open(hhr_path) as f: + hhr = f.read() + return hhr diff --git a/tests/st/mindsponge/test_megafold/data/kalign.py b/tests/st/mindsponge/test_megafold/data/kalign.py new file mode 100644 index 000000000..3d4455005 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/kalign.py @@ -0,0 +1,97 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Kalign tools. +""" + +import os +import stat +import subprocess + +from absl import logging +from mindsponge.data.utils import timing +from data.utils import tmpdir_manager + + +class Kalign: + """Python wrapper of the Kalign binary.""" + + def __init__(self, *, binary_path: str): + """Initializes the Python Kalign wrapper. + Cited from "https://github.com/deepmind/alphafold" + Args: + binary_path: The path to the Kalign binary. + """ + self.binary_path = binary_path + + @staticmethod + def to_a3m(sequences): + """Converts sequences to an a3m file.""" + names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] + a3m = [] + for sequence, name in zip(sequences, names): + a3m.append(u'>' + name + u'\n') + a3m.append(sequence + u'\n') + return ''.join(a3m) + + + def align(self, sequences): + """Aligns the sequences and returns the alignment in A3M string. + + Args: + sequences: A list of query sequence strings. The sequences have to be at + least 6 residues long (Kalign requires this). Note that the order in + which you give the sequences might alter the output slightly as + different alignment tree might get constructed. + + Returns: + A string with the alignment in a3m format. + + Raises: + RuntimeError: If Kalign fails. + ValueError: If any of the sequences is less than 6 residues long. + """ + logging.info('Aligning %d sequences', len(sequences)) + + for s in sequences: + if len(s) < 6: + raise ValueError('Kalign requires all sequences to be at least 6 ' + 'residues long. Got %s (%d residues).' 
% (s, len(s))) + + with tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') + output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + with os.fdopen(os.open(input_fasta_path, os.O_RDWR|os.O_CREAT, stat.S_IRWXU), 'w') as f: + f.write(self.to_a3m(sequences)) + + cmd = [self.binary_path, '-i', input_fasta_path, '-o', output_a3m_path, '-format', 'fasta',] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with timing('Kalign query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', stdout.decode('utf-8'), stderr.decode('utf-8')) + + if retcode: + raise RuntimeError( + 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_a3m_path) as f: + a3m = f.read() + + return a3m diff --git a/tests/st/mindsponge/test_megafold/data/msa_query.py b/tests/st/mindsponge/test_megafold/data/msa_query.py new file mode 100644 index 000000000..e374a6043 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/msa_query.py @@ -0,0 +1,83 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +MSA query tools. 
+""" + +import os + + +class MmseqQuery: + """Runs the alignment tools""" + + def __init__(self, + database_envdb_dir, + mmseqs_binary, + database_dir, + result_path, + msa_search_sh=os.path.join(os.path.dirname(__file__), + "msa_search.sh")): + """Search the a3m info for a given FASTA file.""" + + self.database_envdb_dir = database_envdb_dir + self.mmseqs_binary = mmseqs_binary + self.database_dir = database_dir + self.result_path = result_path + self.msa_search_sh = msa_search_sh + + @staticmethod + def get_a3mlines(a3m_paths): + """combine a3m files together""" + a3m_lines = {} + for a3m_file in a3m_paths: + update_m, m = True, None + with open(a3m_file, "r") as f: + lines = f.readlines() + for line in lines: + if "\x00" in line: + line = line.replace("\x00", "") + update_m = True + if line.startswith(">") and update_m: + try: + m = int(line.strip()[-1]) + except ValueError: + m = str(line.strip()[-1]) + update_m = False + if m not in a3m_lines: + a3m_lines[m] = [] + a3m_lines.get(m).append(line) + a3m_lines = ["".join(a3m_lines.get(key)) for key in a3m_lines] + return a3m_lines[0] + + def msa_query(self, fasta_path, result_path): + """main entry for msa_query""" + if self.database_envdb_dir: + command = f"sh {self.msa_search_sh} {self.mmseqs_binary} " + fasta_path + " " + result_path + " " + \ + self.database_dir + " " + "\"\"" + " " + self.database_envdb_dir + " \"1\" \"0\" \"1\"" + else: + command = f"sh {self.msa_search_sh} {self.mmseqs_binary} " + fasta_path + " " + result_path + " " + \ + self.database_dir + " " + "\"\"" + " \"\"" + " \"0\" \"0\" \"1\"" + os.system(command) + a3m_file_path = os.listdir(result_path) + a3m_file_path = [os.path.join(result_path, x) for x in a3m_file_path if x.endswith("a3m")] + return a3m_file_path + + def aligned_a3m_files(self, input_fasta_path, result_path): + """Runs alignment tools on the input sequence and creates features.""" + + a3m_file_paths = self.msa_query(fasta_path=input_fasta_path, result_path=result_path) + a3m_lines = self.get_a3mlines(a3m_paths=a3m_file_paths) + + return a3m_lines diff --git a/tests/st/mindsponge/test_megafold/data/msa_search.sh b/tests/st/mindsponge/test_megafold/data/msa_search.sh new file mode 100644 index 000000000..85c3c182c --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/msa_search.sh @@ -0,0 +1,62 @@ +#!/bin/bash -e +# cited from https://github.com/sokrypton/ColabFold +MMSEQS="$1" #mmseqs +QUERY="$2" #"/path/QUERY.fasta" +BASE="$3" #"./result/" +DB1="$4" #"uniref30_2103_db" +DB2="$5" #"" +DB3="$6" #"colabfold_envdb_202108_db" +USE_ENV="$7" #1 +USE_TEMPLATES="$8" #0 +FILTER="${9}" #1 +EXPAND_EVAL=inf +ALIGN_EVAL=10 +DIFF=3000 +QSC=-20.0 +MAX_ACCEPT=1000000 +time=$(date ) +echo "${time}" +if [ "${FILTER}" = "1" ]; then +# 0.1 was not used in benchmarks due to POSIX shell bug in line above +# EXPAND_EVAL=0.1 + ALIGN_EVAL=10 + QSC=0.8 + MAX_ACCEPT=100000 +fi +export MMSEQS_CALL_DEPTH=1 +SEARCH_PARAM="--num-iterations 3 --db-load-mode 2 -a -s 8 -e 0.1 --max-seqs 10000" +FILTER_PARAM="--filter-msa ${FILTER} --filter-min-enable 1000 --diff ${DIFF} --qid 0.0,0.2,0.4,0.6,0.8,1.0 --qsc 0 --max-seq-id 0.95" +EXPAND_PARAM="--expansion-mode 0 -e ${EXPAND_EVAL} --expand-filter-clusters ${FILTER} --max-seq-id 0.95" +mkdir -p "${BASE}" +"${MMSEQS}" createdb "${QUERY}" "${BASE}/qdb" +"${MMSEQS}" search "${BASE}/qdb" "${DB1}" "${BASE}/res" "${BASE}/tmp" $SEARCH_PARAM +"${MMSEQS}" expandaln "${BASE}/qdb" "${DB1}.idx" "${BASE}/res" "${DB1}.idx" "${BASE}/res_exp" --db-load-mode 2 ${EXPAND_PARAM} + +"${MMSEQS}" mvdb 
"${BASE}/tmp/latest/profile_1" "${BASE}/prof_res" +"${MMSEQS}" lndb "${BASE}/qdb_h" "${BASE}/prof_res_h" +"${MMSEQS}" align "${BASE}/prof_res" "${DB1}.idx" "${BASE}/res_exp" "${BASE}/res_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a +"${MMSEQS}" filterresult "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign" "${BASE}/res_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100 +"${MMSEQS}" result2msa "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign_filter" "${BASE}/uniref.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM} +"${MMSEQS}" rmdb "${BASE}/res_exp_realign" +"${MMSEQS}" rmdb "${BASE}/res_exp" +"${MMSEQS}" rmdb "${BASE}/res" +"${MMSEQS}" rmdb "${BASE}/res_exp_realign_filter" +if [ "${USE_TEMPLATES}" = "1" ]; then + "${MMSEQS}" search "${BASE}/prof_res" "${DB2}" "${BASE}/res_pdb" "${BASE}/tmp" --db-load-mode 2 -s 7.5 -a -e 0.1 + echo "-----------------------in here" + echo "${BASE}/${DB2}.m8" + "${MMSEQS}" convertalis "${BASE}/prof_res" "${DB2}.idx" "${BASE}/res_pdb" "${BASE}/${DB2}.m8" --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar --db-load-mode 2 + "${MMSEQS}" rmdb "${BASE}/res_pdb" +fi +if [ "${USE_ENV}" = "1" ]; then + "${MMSEQS}" search "${BASE}/prof_res" "${DB3}" "${BASE}/res_env" "${BASE}/tmp" $SEARCH_PARAM + "${MMSEQS}" expandaln "${BASE}/prof_res" "${DB3}.idx" "${BASE}/res_env" "${DB3}.idx" "${BASE}/res_env_exp" -e ${EXPAND_EVAL} --expansion-mode 0 --db-load-mode 2 + "${MMSEQS}" align "${BASE}/tmp/latest/profile_1" "${DB3}.idx" "${BASE}/res_env_exp" "${BASE}/res_env_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a + "${MMSEQS}" filterresult "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign" "${BASE}/res_env_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100 + "${MMSEQS}" result2msa "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign_filter" "${BASE}/bfd.mgnify30.metaeuk30.smag30.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM} + "${MMSEQS}" rmdb "${BASE}/res_env_exp_realign_filter" + "${MMSEQS}" rmdb "${BASE}/res_env_exp_realign" +fi + +time=$(date ) +echo "${time}" diff --git a/tests/st/mindsponge/test_megafold/data/preprocess.py b/tests/st/mindsponge/test_megafold/data/preprocess.py new file mode 100644 index 000000000..5c208ffa3 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/preprocess.py @@ -0,0 +1,523 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""data process""" +import numpy as np +from mindspore import Tensor + +from mindsponge.data.data_transform import one_hot, correct_msa_restypes, randomly_replace_msa_with_unknown, \ + fix_templates_aatype, pseudo_beta_fn, make_atom14_masks, \ + block_delete_msa_indices, sample_msa, make_masked_msa, \ + nearest_neighbor_clusters, summarize_clusters, crop_extra_msa, \ + make_msa_feat, random_crop_to_size +from mindsponge.common.residue_constants import atom_type_num + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' +NUM_SEQ = "length msa placeholder" +_MSA_FEATURE_NAMES = ['msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'] + +FEATURES = { + # Static features of a protein sequence + "aatype": (np.float32, [NUM_RES, 21]), + "between_segment_residues": (np.int64, [NUM_RES, 1]), + "deletion_matrix": (np.float32, [NUM_SEQ, NUM_RES, 1]), + "msa": (np.int64, [NUM_SEQ, NUM_RES, 1]), + "num_alignments": (np.int64, [NUM_RES, 1]), + "residue_index": (np.int64, [NUM_RES, 1]), + "seq_length": (np.int64, [NUM_RES, 1]), + "all_atom_positions": (np.float32, [NUM_RES, atom_type_num, 3]), + "all_atom_mask": (np.int64, [NUM_RES, atom_type_num]), + "resolution": (np.float32, [1]), + "template_domain_names": (str, [NUM_TEMPLATES]), + "template_sum_probs": (np.float32, [NUM_TEMPLATES, 1]), + "template_aatype": (np.float32, [NUM_TEMPLATES, NUM_RES, 22]), + "template_all_atom_positions": (np.float32, [NUM_TEMPLATES, NUM_RES, atom_type_num, 3]), + "template_all_atom_masks": (np.float32, [NUM_TEMPLATES, NUM_RES, atom_type_num, 1]), + "atom14_atom_exists": (np.float32, [NUM_RES, 14]), + "atom14_gt_exists": (np.float32, [NUM_RES, 14]), + "atom14_gt_positions": (np.float32, [NUM_RES, 14, 3]), + "residx_atom14_to_atom37": (np.float32, [NUM_RES, 14]), + "residx_atom37_to_atom14": (np.float32, [NUM_RES, 37]), + "atom37_atom_exists": (np.float32, [NUM_RES, 37]), + "atom14_alt_gt_positions": (np.float32, [NUM_RES, 14, 3]), + "atom14_alt_gt_exists": (np.float32, [NUM_RES, 14]), + "atom14_atom_is_ambiguous": (np.float32, [NUM_RES, 14]), + "rigidgroups_gt_frames": (np.float32, [NUM_RES, 8, 12]), + "rigidgroups_gt_exists": (np.float32, [NUM_RES, 8]), + "rigidgroups_group_exists": (np.float32, [NUM_RES, 8]), + "rigidgroups_group_is_ambiguous": (np.float32, [NUM_RES, 8]), + "rigidgroups_alt_gt_frames": (np.float32, [NUM_RES, 8, 12]), + "backbone_affine_tensor": (np.float32, [NUM_RES, 7]), + "torsion_angles_sin_cos": (np.float32, [NUM_RES, 4, 2]), + "torsion_angles_mask": (np.float32, [NUM_RES, 7]), + "pseudo_beta": (np.float32, [NUM_RES, 3]), + "pseudo_beta_mask": (np.float32, [NUM_RES,]), + "chi_mask": (np.float32, [NUM_RES, 4]), + "backbone_affine_mask": (np.float32, [NUM_RES,]), +} + +feature_list = { + 'aatype': [NUM_RES], + 'all_atom_mask': [NUM_RES, None], + 'all_atom_positions': [NUM_RES, None, None], + 'alt_chi_angles': [NUM_RES, None], + 'atom14_alt_gt_exists': [NUM_RES, None], + 'atom14_alt_gt_positions': [NUM_RES, None, None], + 'atom14_atom_exists': [NUM_RES, None], + 'atom14_atom_is_ambiguous': [NUM_RES, None], + 'atom14_gt_exists': [NUM_RES, None], + 'atom14_gt_positions': [NUM_RES, None, None], + 'atom37_atom_exists': [NUM_RES, None], + 'backbone_affine_mask': [NUM_RES], + 'backbone_affine_tensor': [NUM_RES, None], + 'bert_mask': [NUM_MSA_SEQ, NUM_RES], + 'chi_angles': [NUM_RES, None], + 'chi_mask': 
[NUM_RES, None], + 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_row_mask': [NUM_EXTRA_SEQ], + 'is_distillation': [], + 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], + 'msa_mask': [NUM_MSA_SEQ, NUM_RES], + 'msa_row_mask': [NUM_MSA_SEQ], + 'pseudo_beta': [NUM_RES, None], + 'pseudo_beta_mask': [NUM_RES], + 'random_crop_to_size_seed': [None], + 'residue_index': [NUM_RES], + 'residx_atom14_to_atom37': [NUM_RES, None], + 'residx_atom37_to_atom14': [NUM_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [NUM_RES, None, None], + 'rigidgroups_group_exists': [NUM_RES, None], + 'rigidgroups_group_is_ambiguous': [NUM_RES, None], + 'rigidgroups_gt_exists': [NUM_RES, None], + 'rigidgroups_gt_frames': [NUM_RES, None, None], + 'seq_length': [], + 'seq_mask': [NUM_RES], + 'target_feat': [NUM_RES, None], + 'template_aatype': [NUM_TEMPLATES, NUM_RES], + 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_positions': [ + NUM_TEMPLATES, NUM_RES, None, None], + 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], + 'template_backbone_affine_tensor': [ + NUM_TEMPLATES, NUM_RES, None], + 'template_mask': [NUM_TEMPLATES], + 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], + 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], + 'template_sum_probs': [NUM_TEMPLATES, None], + 'true_msa': [NUM_MSA_SEQ, NUM_RES], + 'torsion_angles_sin_cos': [NUM_RES, None, None] +} + + +def feature_shape(feature_name, num_residues, msa_length, num_templates, features=None): + """Get the shape for the given feature name.""" + features = features or FEATURES + if feature_name.endswith("_unnormalized"): + feature_name = feature_name[:-13] + unused_dtype, raw_sizes = features.get(feature_name, (None, None)) + replacements = {NUM_RES: num_residues, + NUM_SEQ: msa_length} + + if num_templates is not None: + replacements[NUM_TEMPLATES] = num_templates + + sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] + for dimension in sizes: + if isinstance(dimension, str): + raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( + feature_name, raw_sizes, replacements)) + size_r = [int(x) for x in sizes] + return size_r + + +def parse_reshape_logic(parsed_features, features, num_template, key=None): + """Transforms parsed serial features to the correct shape.""" + # Find out what is the number of sequences and the number of alignments. + num_residues = np.reshape(parsed_features['seq_length'].astype(np.int32), (-1,))[0] + + if "num_alignments" in parsed_features: + num_msa = np.reshape(parsed_features["num_alignments"].astype(np.int32), (-1,))[0] + else: + num_msa = 0 + + if key is not None and "key" in features: + parsed_features["key"] = [key] # Expand dims from () to (1,). + + # Reshape the arrays according to the sequence length and num alignments. + for k, v in parsed_features.items(): + new_shape = feature_shape( + feature_name=k, + num_residues=num_residues, + msa_length=num_msa, + num_templates=num_template, + features=features) + new_shape_size = 1 + for dim in new_shape: + new_shape_size *= dim + + if np.size(v) != new_shape_size: + raise ValueError("the size of feature {} ({}) could not be reshaped into {}" + "".format(k, np.size(v), new_shape)) + + if "template" not in k: + # Make sure the feature we are reshaping is not empty. 
+ if np.size(v) <= 0: + raise ValueError("The feature {} is empty.".format(k)) + parsed_features[k] = np.reshape(v, new_shape) + + return parsed_features + + +def _make_features_metadata(feature_names): + """Makes a feature name to type and shape mapping from a list of names.""" + # Make sure these features are always read. + required_features = ["sequence", "domain_name", "template_domain_names"] + feature_names = list(set(feature_names) - set(required_features)) + + features_metadata = {name: FEATURES.get(name) for name in feature_names} + return features_metadata + + +def np_to_array_dict(np_example, features): + """Creates dict of arrays. + + Args: + np_example: A dict of NumPy feature arrays. + features: A list of strings of feature names to be returned in the dataset. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + features_metadata = _make_features_metadata(features) + array_dict = {k: v for k, v in np_example.items() if k in features_metadata} + if "template_domain_names" in np_example: + num_template = len(np_example["template_domain_names"]) + else: + num_template = 0 + + # Ensures shapes are as expected. Needed for setting size of empty features + # e.g. when no template hits were found. + array_dict = parse_reshape_logic(array_dict, features_metadata, num_template) + array_dict['template_mask'] = np.ones([num_template], np.float32) + return array_dict + + +class Feature: + """feature process""" + + def __init__(self, cfg, raw_feature=None, is_training=False): + if raw_feature and isinstance(raw_feature, dict): + self.ensemble_num = 0 + if 'deletion_matrix_int' in raw_feature: + raw_feature['deletion_matrix'] = (raw_feature.pop('deletion_matrix_int').astype(np.float32)) + feature_names = cfg.common.unsupervised_features + if cfg.common.use_templates: + feature_names += cfg.common.template_features + self.is_training = is_training + if self.is_training: + feature_names += cfg.common.supervised_features + raw_feature = np_to_array_dict(np_example=raw_feature, features=feature_names) + + for key in raw_feature: + setattr(self, key, raw_feature[key]) + + def non_ensemble(self, distillation=False, replace_proportion=0.0, use_templates=True): + """non ensemble""" + setattr(self, "msa", correct_msa_restypes(self.msa)) + setattr(self, "is_distillation", np.array(float(distillation), dtype=np.float32)) + # convert int64 to int32 + for k, v in vars(self).items(): + if k not in ("ensemble_num", "is_training"): + if v.dtype == np.int64: + setattr(self, k, v.astype(np.int32)) + aatype = np.argmax(self.aatype, axis=-1) + setattr(self, "aatype", aatype.astype(np.int32)) + data = vars(self) + for k in ['msa', 'num_alignments', 'seq_length', 'sequence', 'superfamily', 'deletion_matrix', + 'resolution', 'between_segment_residues', 'residue_index', 'template_all_atom_masks']: + if k in data: + final_dim = data[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + setattr(self, k, np.squeeze(data[k], axis=-1)) + # Remove fake sequence dimension + for k in ['seq_length', 'num_alignments']: + if k in data: + setattr(self, k, data[k][0]) + + msa, aatype = randomly_replace_msa_with_unknown(self.msa, self.aatype, replace_proportion) + setattr(self, "msa", msa) + setattr(self, "aatype", aatype) + # seq_mask + seq_mask = np.ones(self.aatype.shape, dtype=np.float32) + setattr(self, "seq_mask", seq_mask) + # msa_mask and msa_row_mask + msa_mask = np.ones(self.msa.shape,
dtype=np.float32) + msa_row_mask = np.ones(self.msa.shape[0], dtype=np.float32) + setattr(self, "msa_mask", msa_mask) + setattr(self, "msa_row_mask", msa_row_mask) + if 'hhblits_profile' not in data: + # Compute the profile for every residue (over all MSA sequences). + setattr(self, 'hhblits_profile', np.mean(one_hot(22, self.msa), axis=0)) + + if use_templates: + template_aatype = fix_templates_aatype(self.template_aatype) + setattr(self, "template_aatype", template_aatype) + template_pseudo_beta, template_pseudo_beta_mask = pseudo_beta_fn(self.template_aatype, + self.template_all_atom_positions, + self.template_all_atom_masks) + setattr(self, "template_pseudo_beta", template_pseudo_beta) + setattr(self, "template_pseudo_beta_mask", template_pseudo_beta_mask) + + atom14_atom_exists, residx_atom14_to_atom37, residx_atom37_to_atom14, atom37_atom_exists = \ + make_atom14_masks(self.aatype) + setattr(self, "atom14_atom_exists", atom14_atom_exists) + setattr(self, "residx_atom14_to_atom37", residx_atom14_to_atom37) + setattr(self, "residx_atom37_to_atom14", residx_atom37_to_atom14) + setattr(self, "atom37_atom_exists", atom37_atom_exists) + + def ensemble(self, data, msa_fraction_per_block=0.3, randomize_num_blocks=True, num_blocks=5, keep_extra=True, + max_msa_clusters=124, masked_msa=None, uniform_prob=0.1, profile_prob=0.1, same_prob=0.1, + replace_fraction=0.15, msa_cluster_features=True, max_extra_msa=1024, crop_size=256, max_templates=4, + subsample_templates=True, fixed_size=True, seed=0, random_recycle=False): + """ensemble""" + self.ensemble_num += 1 + if self.is_training: + keep_indices = block_delete_msa_indices(data["msa"], msa_fraction_per_block, randomize_num_blocks, + num_blocks) + for k in _MSA_FEATURE_NAMES: + if k in data: + data[k] = data[k][keep_indices] + # exist numpy random op + is_sel, not_sel_seq, sel_seq = sample_msa(data["msa"], max_msa_clusters) + for k in _MSA_FEATURE_NAMES: + if k in data: + if keep_extra and not is_sel: + new_shape = list(data[k].shape) + new_shape[0] = 1 + data['extra_' + k] = np.zeros(new_shape) + elif keep_extra and is_sel: + data['extra_' + k] = data[k][not_sel_seq] + if k == 'msa': + data['extra_msa'] = data['extra_msa'].astype(np.int32) + data[k] = data[k][sel_seq] + if masked_msa: + data["bert_mask"], data["true_msa"], data["msa"] = make_masked_msa(data["msa"], data["hhblits_profile"], + uniform_prob, profile_prob, same_prob, + replace_fraction) + if msa_cluster_features: + data["extra_cluster_assignment"] = nearest_neighbor_clusters(data["msa_mask"], data["msa"], + data["extra_msa_mask"], data["extra_msa"]) + data["cluster_profile"], data["cluster_deletion_mean"] = summarize_clusters(data["msa"], data["msa_mask"], + data[ + "extra_cluster_assignment"], + data["extra_msa_mask"], + data["extra_msa"], + data["extra_deletion_matrix"], + data["deletion_matrix"]) + + if max_extra_msa: + select_indices = crop_extra_msa(data["extra_msa"], max_extra_msa) + if select_indices: + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in data: + data['extra_' + k] = data['extra_' + k][select_indices] + else: + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in data: + del data['extra_' + k] + data["extra_has_deletion"], data["extra_deletion_value"], data["msa_feat"], data["target_feat"] = make_msa_feat( + data["between_segment_residues"], data["aatype"], data["msa"], data["deletion_matrix"], + data["cluster_deletion_mean"], data["cluster_profile"], data["extra_deletion_matrix"]) + if fixed_size: + data = {k: v for k, v in data.items() if k in feature_list} + + 
num_res_crop_size, num_templates_crop_size_int, num_res_crop_start, num_res_crop_size_int, \ + templates_crop_start, templates_select_indices = random_crop_to_size( + data["seq_length"], data["template_mask"], crop_size, max_templates, + subsample_templates, seed, random_recycle) + for k, v in data.items(): + if k not in feature_list or ('template' not in k and NUM_RES not in feature_list.get(k)): + continue + + # randomly permute the templates before cropping them. + if k.startswith('template') and subsample_templates: + v = v[templates_select_indices] + + crop_sizes = [] + crop_starts = [] + for i, (dim_size, dim) in enumerate(zip(feature_list.get(k), v.shape)): + is_num_res = (dim_size == NUM_RES) + if i == 0 and k.startswith('template'): + crop_size_ = num_templates_crop_size_int + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size_ = (num_res_crop_size_int if is_num_res else (-1 if dim is None else dim)) + crop_sizes.append(crop_size_) + crop_starts.append(crop_start) + if len(v.shape) == 1: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0]] + elif len(v.shape) == 2: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1]] + elif len(v.shape) == 3: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1], + crop_starts[2]:crop_starts[2] + crop_sizes[2]] + else: + data[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], + crop_starts[1]:crop_starts[1] + crop_sizes[1], + crop_starts[2]:crop_starts[2] + crop_sizes[2], + crop_starts[3]:crop_starts[3] + crop_sizes[3]] + + data["seq_length"] = num_res_crop_size + + pad_size_map = { + NUM_RES: crop_size, + NUM_MSA_SEQ: max_msa_clusters, + NUM_EXTRA_SEQ: max_extra_msa, + NUM_TEMPLATES: max_templates, + } + + for k, v in data.items(): + if k == 'extra_cluster_assignment': + continue + shape = list(v.shape) + schema = feature_list.get(k) + assert len(shape) == len( + schema), f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}' + + pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + if padding: + data[k] = np.pad(v, padding) + data[k] = data[k].reshape(pad_size) + else: + for k, v in data.items(): + if k.startswith('template_'): + data[k] = v[:max_templates] + return data + + def process_res(self, features, res, dtype): + """process result""" + arrays, prev_pos, prev_msa_first_row, prev_pair = res + if self.is_training: + label_keys = ["pseudo_beta", "pseudo_beta_mask", "all_atom_mask", + "true_msa", "bert_mask", "residue_index", "seq_mask", + "atom37_atom_exists", "aatype", "residx_atom14_to_atom37", + "atom14_atom_exists", "backbone_affine_tensor", "backbone_affine_mask", + "atom14_gt_positions", "atom14_alt_gt_positions", + "atom14_atom_is_ambiguous", "atom14_gt_exists", "atom14_alt_gt_exists", + "all_atom_positions", "rigidgroups_gt_frames", "rigidgroups_gt_exists", + "rigidgroups_alt_gt_frames", "torsion_angles_sin_cos", "chi_mask"] + label_arrays = [features[key] for key in label_keys] + label_arrays = [array[0] for array in label_arrays] + label_arrays = [array.astype(dtype) if array.dtype == "float64" else array for array in label_arrays] + label_arrays = [array.astype(dtype) if array.dtype == "float32" else array for array in label_arrays] + res = [arrays, prev_pos, prev_msa_first_row, prev_pair, label_arrays] + return res + return res + + def
pipeline(self, cfg, mixed_precision=True, seed=0): + """feature process pipeline""" + self.non_ensemble(cfg.common.distillation, cfg.common.replace_proportion, cfg.common.use_templates) + non_ensemble_data = vars(self).copy() + max_msa_clusters = cfg.eval.max_msa_clusters + if cfg.common.reduce_msa_clusters_by_max_templates: + max_msa_clusters = cfg.eval.max_msa_clusters - cfg.eval.max_templates + random_recycle = cfg.common.random_recycle + non_ensemble_data_copy = non_ensemble_data.copy() + protein = self.ensemble(non_ensemble_data_copy, + cfg.block_deletion.msa_fraction_per_block, + cfg.block_deletion.randomize_num_blocks, + cfg.block_deletion.num_blocks, + cfg.eval.keep_extra, + max_msa_clusters, + cfg.common.masked_msa.use_masked_msa, + cfg.common.masked_msa.uniform_prob, + cfg.common.masked_msa.profile_prob, + cfg.common.masked_msa.same_prob, + cfg.eval.masked_msa_replace_fraction, + cfg.common.msa_cluster_features, + cfg.common.max_extra_msa, + cfg.eval.crop_size, + cfg.eval.max_templates, + cfg.eval.subsample_templates, + cfg.eval.fixed_size, + seed, + random_recycle) + num_ensemble = cfg.eval.num_ensemble + num_recycle = cfg.common.num_recycle + if cfg.common.resample_msa_in_recycling: + # Separate batch per ensembling & recycling step. + num_ensemble *= num_recycle + result_array = {x: () for x in protein.keys()} + if num_ensemble > 1: + for _ in range(num_ensemble): + non_ensemble_data_copy = non_ensemble_data.copy() + data_t = self.ensemble(non_ensemble_data_copy, + cfg.block_deletion.msa_fraction_per_block, + cfg.block_deletion.randomize_num_blocks, + cfg.block_deletion.num_blocks, + cfg.eval.keep_extra, + max_msa_clusters, + cfg.common.masked_msa.use_masked_msa, + cfg.common.masked_msa.uniform_prob, + cfg.common.masked_msa.profile_prob, + cfg.common.masked_msa.same_prob, + cfg.eval.masked_msa_replace_fraction, + cfg.common.msa_cluster_features, + cfg.common.max_extra_msa, + cfg.eval.crop_size, + cfg.eval.max_templates, + cfg.eval.subsample_templates, + cfg.eval.fixed_size, + seed, + random_recycle) + for key in protein.keys(): + result_array[key] += (data_t[key][None],) + for key in protein.keys(): + result_array[key] = np.concatenate(result_array[key], axis=0) + else: + result_array = {key: protein[key][None] for key in protein.keys()} + features = {k: v for k, v in result_array.items() if v.dtype != 'O'} + extra_msa_length = cfg.common.max_extra_msa + for key in ["extra_msa", "extra_has_deletion", "extra_deletion_value", "extra_msa_mask"]: + features[key] = features[key][:, :extra_msa_length] + input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype', + 'template_all_atom_masks', 'template_all_atom_positions', 'template_mask', + 'template_pseudo_beta_mask', 'template_pseudo_beta', + 'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index'] + + dtype = np.float32 + if mixed_precision: + dtype = np.float16 + arrays = [features[key] for key in input_keys] + arrays = [array.astype(dtype) if array.dtype == "float64" else array for array in arrays] + arrays = [array.astype(dtype) if array.dtype == "float32" else array for array in arrays] + prev_pos = Tensor(np.zeros([cfg.eval.crop_size, 37, 3]).astype(dtype)) + prev_msa_first_row = Tensor(np.zeros([cfg.eval.crop_size, 256]).astype(dtype)) + prev_pair = Tensor(np.zeros([cfg.eval.crop_size, cfg.eval.crop_size, 128]).astype(dtype)) + res = [arrays, prev_pos, prev_msa_first_row, prev_pair] + res = 
self.process_res(features, res, dtype) + return res diff --git a/tests/st/mindsponge/test_megafold/data/protein_feature.py b/tests/st/mindsponge/test_megafold/data/protein_feature.py new file mode 100644 index 000000000..7272ae895 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/data/protein_feature.py @@ -0,0 +1,132 @@ +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +protein feature generation module. +""" + +import pickle +import os +import stat +import numpy as np + +from mindsponge.data.parsers import parse_fasta, parse_hhr, parse_a3m +from mindsponge.common import residue_constants +from data.templates import TemplateHitFeaturizer +from data.hhsearch import HHSearch + + +def make_msa_features(msas, deletion_matrices): + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError(f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append([residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) + + num_res = len(msas[0][0]) + num_alignments = len(int_msa) + features = {'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32), + 'msa': np.array(int_msa, dtype=np.int32), + 'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32)} + return features + + +def make_sequence_features(sequence: str, description: str, num_res: int): + """Constructs a feature dict of sequence features.""" + features = {'aatype': residue_constants.sequence_to_onehot(sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True), + 'between_segment_residues': np.zeros((num_res,), dtype=np.int32), + 'domain_name': np.array([description.encode('utf-8')], dtype=np.object_), + 'residue_index': np.array(range(num_res), dtype=np.int32), + 'seq_length': np.array([num_res] * num_res, dtype=np.int32), + 'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_)} + return features + + +class RawFeatureGenerator: + """Runs the alignment tools""" + + def __init__(self, + template_mmcif_dir, + max_template_date, + kalign_binary_path, + obsolete_pdbs_path, + hhsearch_binary_path, + pdb70_database_path, + result_path, + max_hits=20): + """Search the a3m info for a given FASTA file.""" + + self.template_mmcif_dir = template_mmcif_dir + self.max_template_date = max_template_date + self.kalign_binary_path = kalign_binary_path + self.obsolete_pdbs_path = obsolete_pdbs_path + self.hhsearch_binary_path = hhsearch_binary_path + self.pdb70_database_path = pdb70_database_path + 
self.result_path = result_path
+        self.max_hits = max_hits
+        self.hhsearch_pdb70_runner = HHSearch(binary_path=hhsearch_binary_path, databases=[pdb70_database_path])
+
+    def raw_feature_generate(self, raw_feature_path, fasta_path, a3m_lines):
+        """protein raw feature generation"""
+        template_featurizer = TemplateHitFeaturizer(mmcif_dir=self.template_mmcif_dir,
+                                                    max_template_date=self.max_template_date,
+                                                    max_hits=self.max_hits,
+                                                    kalign_binary_path=self.kalign_binary_path,
+                                                    release_dates_path=None,
+                                                    obsolete_pdbs_path=self.obsolete_pdbs_path)
+        with open(fasta_path) as f:
+            input_fasta_str = f.read()
+        input_seqs, input_descs = parse_fasta(input_fasta_str)
+        if len(input_seqs) != 1:
+            raise ValueError(f'More than one input sequence found in {fasta_path}.')
+        input_sequence = input_seqs[0]
+        input_description = input_descs[0]
+
+        num_res = len(input_sequence)
+
+        hhsearch_result = self.hhsearch_pdb70_runner.query(a3m_lines)
+        hhsearch_hits = parse_hhr(hhsearch_result)
+
+        msas, deletion_matrices = parse_a3m(a3m_lines)
+        templates_result = template_featurizer.get_templates(
+            query_sequence=input_sequence,
+            query_pdb_code=None,
+            query_release_date=None,
+            hhr_hits=hhsearch_hits)
+        sequence_features = make_sequence_features(
+            sequence=input_sequence,
+            description=input_description,
+            num_res=num_res)
+        msa_features = make_msa_features(msas=(msas,), deletion_matrices=(deletion_matrices,))
+
+        feature_dict = {**sequence_features, **msa_features, **templates_result.features}
+        os.makedirs(raw_feature_path, exist_ok=True)
+        pkl_path = os.path.join(raw_feature_path, fasta_path.split(".fasta")[0] + '.pkl')
+        with os.fdopen(os.open(pkl_path, os.O_RDWR | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
+            pickle.dump(feature_dict, f, protocol=4)
+        return feature_dict
diff --git a/tests/st/mindsponge/test_megafold/data/templates.py b/tests/st/mindsponge/test_megafold/data/templates.py
new file mode 100644
index 000000000..5bf51fb31
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/data/templates.py
@@ -0,0 +1,920 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
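A hypothetical end-to-end call of the RawFeatureGenerator defined above. Every path below is a placeholder for illustration only, and the a3m content is assumed to come from a prior MSA search (see msa_search.sh):

```python
from data.protein_feature import RawFeatureGenerator

generator = RawFeatureGenerator(
    template_mmcif_dir='/data/pdb_mmcif/mmcif_files',   # placeholder path
    max_template_date='2021-09-30',                     # ISO8601, YYYY-MM-DD
    kalign_binary_path='/usr/bin/kalign',               # placeholder path
    obsolete_pdbs_path='/data/pdb_mmcif/obsolete.dat',  # placeholder path
    hhsearch_binary_path='/usr/bin/hhsearch',           # placeholder path
    pdb70_database_path='/data/pdb70/pdb70',            # placeholder path
    result_path='/tmp/msa_result')                      # placeholder path

with open('/tmp/msa_result/T1070-D2.a3m') as f:         # a3m from an MSA search
    a3m_lines = f.read()
feature_dict = generator.raw_feature_generate(
    raw_feature_path='/tmp/raw_feature',
    fasta_path='T1070-D2.fasta',
    a3m_lines=a3m_lines)
```

Given fasta_path='T1070-D2.fasta', the merged feature dict is also pickled to raw_feature_path/T1070-D2.pkl for reuse.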
+# ============================================================================
+'''template'''
+import datetime
+import glob
+import os
+import re
+import dataclasses
+from typing import Any, Mapping, Optional, Sequence, Tuple
+from absl import logging
+import numpy as np
+
+from mindsponge.data.parsers import parse_mmcif, parse_a3m
+from mindsponge.common import residue_constants
+from data.kalign import Kalign
+
+
+class Error(Exception):
+    """Base class for exceptions."""
+
+
+class NoChainsError(Error):
+    """An error indicating that template mmCIF didn't have any chains."""
+
+
+class SequenceNotInTemplateError(Error):
+    """An error indicating that template mmCIF didn't contain the sequence."""
+
+
+class NoAtomDataInTemplateError(Error):
+    """An error indicating that template mmCIF didn't contain atom positions."""
+
+
+class TemplateAtomMaskAllZerosError(Error):
+    """An error indicating that template mmCIF had all atom positions masked."""
+
+
+class QueryToTemplateAlignError(Error):
+    """An error indicating that the query can't be aligned to the template."""
+
+
+class CaDistanceError(Error):
+    """An error indicating that a CA atom distance exceeds a threshold."""
+
+
+class MultipleChainsError(Error):
+    """An error indicating that multiple chains were found for a given ID."""
+
+
+# Prefilter exceptions.
+class PrefilterError(Exception):
+    """A base class for template prefilter exceptions."""
+
+
+class DateError(PrefilterError):
+    """An error indicating that the hit date was after the max allowed date."""
+
+
+class PdbIdError(PrefilterError):
+    """An error indicating that the hit PDB ID was identical to the query."""
+
+
+class AlignRatioError(PrefilterError):
+    """An error indicating that the hit align ratio to the query was too small."""
+
+
+class DuplicateError(PrefilterError):
+    """An error indicating that the hit was an exact subsequence of the query."""
+
+
+class LengthError(PrefilterError):
+    """An error indicating that the hit was too short."""
+
+
+TEMPLATE_FEATURES = {
+    'template_aatype': np.float32,
+    'template_all_atom_masks': np.float32,
+    'template_all_atom_positions': np.float32,
+    'template_domain_names': object,
+    'template_e_value': np.float32,
+    'template_neff': np.float32,
+    'template_prob_true': np.float32,
+    'template_release_date': object,
+    'template_score': np.float32,
+    'template_similarity': np.float32,
+    'template_sequence': object,
+    'template_sum_probs': np.float32,
+    'template_confidence_scores': np.int64
+}
+
+
+def _get_pdb_id_and_chain(hit):
+    """Returns PDB id and chain id for an HHSearch Hit."""
+    # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
+    id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
+    if not id_match:
+        raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
+    pdb_id, chain_id = id_match.group(0).split('_')
+    return pdb_id.lower(), chain_id
+
+
+def _is_after_cutoff(
+        pdb_id: str,
+        release_dates: Mapping[str, datetime.datetime],
+        release_date_cutoff: Optional[datetime.datetime]) -> bool:
+    """Checks if the template date is after the release date cutoff.
+
+    Args:
+      pdb_id: 4 letter pdb code.
+      release_dates: Dictionary mapping PDB ids to their structure release dates.
+      release_date_cutoff: Max release date that is valid for this query.
+
+    Returns:
+      True if the template release date is after the cutoff, False otherwise.
+ """ + if release_date_cutoff is None: + raise ValueError('The release_date_cutoff must not be None.') + if pdb_id in release_dates: + return release_dates[pdb_id] > release_date_cutoff + return False + + +def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: + """Parses the data file from PDB that lists which PDB ids are obsolete.""" + with open(obsolete_file_path) as f: + result = {} + for line in f: + line = line.strip() + # We skip obsolete entries that don't contain a mapping to a new entry. + if line.startswith('OBSLTE') and len(line) > 30: + # Format: Date From To + # 'OBSLTE 31-JUL-94 116L 216L' + from_id = line[20:24].lower() + to_id = line[29:33].lower() + result[from_id] = to_id + return result + + +def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: + """Parses release dates file, returns a mapping from PDBs to release dates.""" + if path.endswith('txt'): + release_dates = {} + with open(path, 'r') as f: + for line in f: + pdb_id, date = line.split(':') + date = date.strip() + # Python 3.6 doesn't have datetime.date.fromisoformat() which is about 90x faster than strptime. + # However, splitting the string manually is about 10x faster than strptime. + release_dates[pdb_id.strip()] = \ + datetime.datetime(year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10])) + return release_dates + raise ValueError('Invalid format of the release date file %s.' % path) + + +def _assess_hhsearch_hit( + hit, + hit_pdb_code, + query_sequence, + query_pdb_code, + release_dates, + release_date_cutoff, + max_subsequence_ratio=0.95, + min_align_ratio=0.1): + """Determines if template is valid (without parsing the template mmcif file). + + Args: + hit: HhrHit for the template. + hit_pdb_code: The 4 letter pdb code of the template hit. This might be + different from the value in the actual hit since the original pdb might + have become obsolete. + query_sequence: Amino acid sequence of the query. + query_pdb_code: 4 letter pdb code of the query. + release_dates: Dictionary mapping pdb codes to their structure release + dates. + release_date_cutoff: Max release date that is valid for this query. + max_subsequence_ratio: Exclude any exact matches with this much overlap. + min_align_ratio: Minimum overlap between the template and query. + + Returns: + True if the hit passed the prefilter. Raises an exception otherwise. + + Raises: + DateError: If the hit date was after the max allowed date. + PdbIdError: If the hit PDB ID was identical to the query. + AlignRatioError: If the hit align ratio to the query was too small. + DuplicateError: If the hit was an exact subsequence of the query. + LengthError: If the hit was too short. + """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace('-', '') + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. 
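A quick standalone check of the fixed-column slicing in _parse_obsolete above. The sample line is constructed to match the format documented in the code comment, with the obsolete ID in columns 20:24 and its replacement in 29:33:

```python
line = 'OBSLTE    31-JUL-94 116L     216L'
assert line.startswith('OBSLTE') and len(line) > 30
from_id, to_id = line[20:24].lower(), line[29:33].lower()
print(from_id, '->', to_id)  # 116l -> 216l
```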
+ duplicate = (template_sequence in query_sequence and length_ratio > max_subsequence_ratio) + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date ({release_date_cutoff}).') + + if query_pdb_code is not None: + if query_pdb_code.lower() == hit_pdb_code.lower(): + raise PdbIdError('PDB code identical to Query PDB code.') + + if align_ratio <= min_align_ratio: + raise AlignRatioError(f'Proportion of residues aligned to query too small. Align ratio: {align_ratio}.') + + if duplicate: + raise DuplicateError(f'Template is an exact subsequence of query with large coverage.' + f' Length ratio: {length_ratio}.') + + if len(template_sequence) < 10: + raise LengthError(f'Template too short. Length: {len(template_sequence)}.') + + return True + + +def _find_template_in_pdb(template_chain_id, template_sequence, mmcif_object): + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. + + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the + # (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found an exact template match %s_%s.', pdb_id, template_chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info(f'Found a sequence-only match {pdb_id}_{chain_id}.') + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups + # limit. + regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] + regex = re.compile(''.join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info(f'Found a fuzzy sequence-only match {pdb_id}_{chain_id}.') + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. + raise SequenceNotInTemplateError( + 'Could not find the template sequence in %s_%s. 
Template sequence: %s, ' + 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, mmcif_object.chain_to_seqres)) + + +def _realign_pdb_template_to_query( + old_template_sequence, + template_chain_id, + mmcif_object, + old_mapping, + kalign_binary_path): + """Aligns template from the mmcif_object to the query. + + In case PDB70 contains a different version of the template sequence, we need + to perform a realignment to the actual sequence that is in the mmCIF file. + This method performs such realignment, but returns the new sequence and + mapping only if the sequence in the mmCIF file is 90% identical to the old + sequence. + + Note that the old_template_sequence comes from the hit, and contains only that + part of the chain that matches with the query while the new_template_sequence + is the full chain. + + Args: + old_template_sequence: The template sequence that was returned by the PDB + template search (typically done using HHSearch). + template_chain_id: The template chain id was returned by the PDB template + search (typically done using HHSearch). This is used to find the right + chain in the mmcif_object chain_to_seqres mapping. + mmcif_object: A mmcif_object which holds the actual template data. + old_mapping: A mapping from the query sequence to the template sequence. + This mapping will be used to compute the new mapping from the query + sequence to the actual mmcif_object template sequence by aligning the + old_template_sequence and the actual template sequence. + kalign_binary_path: The path to a kalign executable. + + Returns: + A tuple (new_template_sequence, new_query_to_template_mapping) where: + * new_template_sequence is the actual template sequence that was found in + the mmcif_object. + * new_query_to_template_mapping is the new mapping from the query to the + actual template found in the mmcif_object. + + Raises: + QueryToTemplateAlignError: + * If there was an error thrown by the alignment tool. + * Or if the actual template sequence differs by more than 10% from the + old_template_sequence. + """ + aligner = Kalign(binary_path=kalign_binary_path) + new_template_sequence = mmcif_object.chain_to_seqres.get(template_chain_id, '') + + # Sometimes the template chain id is unknown. But if there is only a single + # sequence within the mmcif_object, it is safe to assume it is that one. + if not new_template_sequence: + if len(mmcif_object.chain_to_seqres) == 1: + logging.info(f'Could not find {template_chain_id} in {mmcif_object.file_id}, but there is only 1 sequence,' + f' so using that one.') + new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0] + else: + raise QueryToTemplateAlignError( + f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. ' + 'If there are no mmCIF parsing errors, it is possible it was not a ' + 'protein chain.') + + try: + (old_aligned_template, new_aligned_template), _ = \ + parse_a3m(aligner.align([old_template_sequence, new_template_sequence])) + except Exception as e: + raise QueryToTemplateAlignError( + 'Could not align old template %s to template %s (%s_%s). 
Error: %s' % + (old_template_sequence, + new_template_sequence, + mmcif_object.file_id, + template_chain_id, + str(e))) + + logging.info(f'Old aligned template: {old_aligned_template}\nNew aligned template: {new_aligned_template}') + + old_to_new_template_mapping = {} + old_template_index = -1 + new_template_index = -1 + num_same = 0 + for old_template_aa, new_template_aa in zip(old_aligned_template, new_aligned_template): + if old_template_aa != '-': + old_template_index += 1 + if new_template_aa != '-': + new_template_index += 1 + if old_template_aa != '-' and new_template_aa != '-': + old_to_new_template_mapping[old_template_index] = new_template_index + if old_template_aa == new_template_aa: + num_same += 1 + + # Require at least 90 % sequence identity wrt to the shorter of the sequences. + if float(num_same) / min(len(old_template_sequence), len(new_template_sequence)) < 0.9: + raise QueryToTemplateAlignError( + 'Insufficient similarity of the sequence in the database: %s to the ' + 'actual sequence in the mmCIF file %s_%s: %s. We require at least ' + '90 %% similarity wrt to the shorter of the sequences. This is not a ' + 'problem unless you think this is a template that should be included.' % + (old_template_sequence, mmcif_object.file_id, template_chain_id, + new_template_sequence)) + + new_query_to_template_mapping = {} + for query_index, old_template_index in old_mapping.items(): + new_query_to_template_mapping[query_index] = (old_to_new_template_mapping.get(old_template_index, -1)) + + new_template_sequence = new_template_sequence.replace('-', '') + + return new_template_sequence, new_query_to_template_mapping + + +def _check_residue_distances(all_positions: np.ndarray, + all_positions_mask: np.ndarray, + max_ca_ca_distance: float): + """Checks if the distance between unmasked neighbor residues is ok.""" + ca_position = residue_constants.atom_order['CA'] + prev_is_unmasked = False + prev_calpha = None + for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): + this_is_unmasked = bool(mask[ca_position]) + if this_is_unmasked: + this_calpha = coords[ca_position] + if prev_is_unmasked: + distance = np.linalg.norm(this_calpha - prev_calpha) + if distance > max_ca_ca_distance: + raise CaDistanceError('The distance between residues %d and %d is %f > limit %f.' 
% + (i, i + 1, distance, max_ca_ca_distance)) + prev_calpha = this_calpha + prev_is_unmasked = this_is_unmasked + + +def _get_atom_positions( + mmcif_object, + auth_chain_id, + max_ca_ca_distance) -> Tuple[np.ndarray, np.ndarray]: + """Gets atom positions and mask from a list of Biopython Residues.""" + num_res = len(mmcif_object.chain_to_seqres[auth_chain_id]) + + relevant_chains = [c for c in mmcif_object.structure.get_chains() if c.id == auth_chain_id] + if len(relevant_chains) != 1: + raise MultipleChainsError(f'Expected exactly one chain in structure with id {auth_chain_id}.') + chain = relevant_chains[0] + + all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3]) + all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num], dtype=np.int64) + for res_index in range(num_res): + pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) + res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index] + if not res_at_position.is_missing: + res = chain[(res_at_position.hetflag, + res_at_position.position.residue_number, + res_at_position.position.insertion_code)] + for atom in res.get_atoms(): + atom_name = atom.get_name() + x, y, z = atom.get_coord() + if atom_name in residue_constants.atom_order.keys(): + pos[residue_constants.atom_order[atom_name]] = [x, y, z] + mask[residue_constants.atom_order[atom_name]] = 1.0 + elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': + # Put the coordinates of the selenium atom in the sulphur + # column. + pos[residue_constants.atom_order['SD']] = [x, y, z] + mask[residue_constants.atom_order['SD']] = 1.0 + + all_positions[res_index] = pos + all_positions_mask[res_index] = mask + _check_residue_distances(all_positions, all_positions_mask, max_ca_ca_distance) + return all_positions, all_positions_mask + + +def _extract_template_features( + mmcif_object, + pdb_id, + mapping, + template_sequence, + query_sequence, + template_chain_id, + confidence_scores, + kalign_binary_path): + """Parses atom positions in the target structure and aligns with the query. + + Atoms for each residue in the template structure are indexed to coincide + with their corresponding residue in the query sequence, according to the + alignment mapping provided. + + Note that we only extract at most 500 templates because of HHSearch settings. + + We set missing/invalid confidence scores to the default value of -1. + Note: We now have 4 types of confidence scores: + 1. Valid scores + 2. Invalid scores of residues not in both the query sequence and template + sequence + 3. Missing scores because we don't have the secondary structure, and HHAlign + doesn't produce the posterior probabilities in this case. + 4. Missing scores because of a different template sequence in PDB70, + invalidating the previously computed confidence scores. (Though in theory + HHAlign can be run on these to recompute the correct confidence scores). + We handle invalid and missing scores by setting them to -1, but consider + adding masks for the different types. + + Args: + mmcif_object: mmcif_parsing.MmcifObject representing the template. + pdb_id: PDB code for the template. + mapping: Dictionary mapping indices in the query sequence to indices in + the template sequence. + template_sequence: String describing the amino acid sequence for the + template protein. + query_sequence: String describing the amino acid sequence for the query + protein. 
+ template_chain_id: String ID describing which chain in the structure proto + should be used. + confidence_scores: String containing per-residue confidence scores, where + each character represents the *TRUNCATED* posterior probability that the + corresponding template residue is correctly aligned with the query + residue, given the database match is correct (0 corresponds approximately + to 0-10%, 9 to 90-100%). + kalign_binary_path: The path to a kalign executable used for template + realignment. + + Returns: + A tuple with: + * A dictionary containing the extra features derived from the template + protein structure. + * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. + NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id)) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = (f'The exact sequence {template_sequence} was not found in ' + f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.') + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path) + logging.info(f'Sequence in {pdb_id}_{chain_id}: {template_sequence} successfully realigned to {seqres}') + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + # Confidence scores were based on the previous sequence, so they are + # invalid + confidence_scores = None + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. 
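Before the extraction body continues, a minimal numpy sketch of the consecutive C-alpha distance check performed by _check_residue_distances earlier in this file. The 150.0 threshold matches the deliberately permissive max_ca_ca_distance used at the call site just below:

```python
import numpy as np

ca = np.array([[0.0, 0.0, 0.0],
               [3.8, 0.0, 0.0],      # typical consecutive CA-CA spacing, ~3.8 A
               [200.0, 0.0, 0.0]])   # a clearly broken chain
dists = np.linalg.norm(np.diff(ca, axis=0), axis=1)
print(dists)                         # [  3.8 196.2]
assert (dists > 150.0).any()         # _check_residue_distances would raise here
```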
+ all_atom_positions, all_atom_mask = _get_atom_positions(mmcif_object, chain_id, max_ca_ca_distance=150.0) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError(f'Could not get atom data ({pdb_id}_{chain_id}): {str(ex)}') + + all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0]) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + output_confidence_scores = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append(np.zeros((residue_constants.atom_type_num, 3))) + templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num)) + output_templates_sequence.append('-') + output_confidence_scores.append(-1) + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + if confidence_scores and confidence_scores[v] != ' ': + output_confidence_scores[k] = int(confidence_scores[v]) + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, + # O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError('Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' % + (pdb_id, chain_id, min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset)) + + output_templates_sequence = ''.join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + + return ( + {'template_all_atom_positions': np.array(templates_all_atom_positions), + 'template_all_atom_masks': np.array(templates_all_atom_masks), + 'template_sequence': output_templates_sequence.encode(), + 'template_aatype': np.array(templates_aatype), + 'template_confidence_scores': np.array(output_confidence_scores), + 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), + 'template_release_date': mmcif_object.header['release_date'].encode()}, + warning) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. 
+ """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace('-', '') + hit_sequence = hit_sequence.replace('-', '') + hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence) + + # Index of -1 used for gap characters. Subtract the min index ignoring + # gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [x - min_idx if x > - 1 else - 1 for x in indices_query] + + # Zip the corrected indices, ignore case where both seqs have gap + # characters. + mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if (q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(original_query_sequence)): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +def _process_single_hit( + query_sequence, + query_pdb_code, + hit, + mmcif_dir, + max_template_date, + release_dates, + obsolete_pdbs, + kalign_binary_path, + strict_error_check): + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being + # obsolete. + try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + release_dates=release_dates, + release_date_cutoff=max_template_date) + except PrefilterError as e: + msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' + logging.info('%s: %s', query_pdb_code, msg) + if strict_error_check and isinstance(e, (DateError, PdbIdError, DuplicateError)): + # In strict mode we treat some prefilter cases as errors. + return SingleHitResult(features=None, error=msg, warning=None) + + return SingleHitResult(features=None, error=None, warning=None) + + mapping = _build_query_to_hit_index_mapping( + hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query, query_sequence) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace('-', '') + confidence_scores = ''.join([cs for t, cs in zip(hit.hit_sequence, hit.confidence_scores) if t != '-']) + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') + if not os.path.exists(cif_path): + cif_path = os.path.join(mmcif_dir, hit_pdb_code.upper() + '.cif') + logging.info('Reading PDB entry from %s. Query: %s, template: %s', cif_path, query_sequence, template_sequence) + # Fail if we can't find the mmCIF file. 
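A worked toy example of the index bookkeeping in _build_query_to_hit_index_mapping above. This simplified version assumes the per-column indices are already zero-based; the real function arranges that by subtracting the minimum non-gap index first:

```python
original_query = 'MKTAYIAKQR'
hit_query = 'TAYI'                        # fragment of the query seen in the .hhr hit
offset = original_query.find(hit_query)   # 2: where the fragment starts

indices_query = [0, 1, 2, 3]   # per aligned column, already zero-based
indices_hit = [0, 1, -1, 2]    # -1 marks a gap in the hit at column 2

mapping = {}
for q_i, t_i in zip(indices_query, indices_hit):
    if q_i != -1 and t_i != -1:   # skip columns that are gaps on either side
        mapping[q_i + offset] = t_i
print(mapping)  # {2: 0, 3: 1, 5: 2}
```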
+ with open(cif_path, 'r') as cif_file: + cif_string = cif_file.read() + + parsing_result = parse_mmcif(file_id=hit_pdb_code, mmcif_string=cif_string) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime(parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') + if hit_release_date > max_template_date: + error = ('Template %s date (%s) > max template date (%s).' % + (hit_pdb_code, hit_release_date, max_template_date)) + if strict_error_check: + return SingleHitResult(features=None, error=error, warning=None) + logging.warning(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + confidence_scores=confidence_scores, + kalign_binary_path=kalign_binary_path) + features['template_e_value'] = [hit.e_value] + features['template_sum_probs'] = [hit.sum_probs] + features['template_prob_true'] = [hit.prob_true] + features['template_score'] = [hit.score] + features['template_neff'] = [hit.neff] + features['template_similarity'] = [hit.similarity] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult(features=features, error=None, warning=realign_warning) + except (NoChainsError, NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. + warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % (hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors)) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % (hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors)) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer: + """A class for turning hhr hits to template features.""" + + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str], + obsolete_pdbs_path: Optional[str], + strict_error_check: bool = False): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. 
+ release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. + * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): + logging.error('Could not find CIFs in %s', self._mmcif_dir) + raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') + + try: + self._max_template_date = datetime.datetime.strptime(max_template_date, '%Y-%m-%d') + except ValueError: + raise ValueError('max_template_date must be set and have format YYYY-MM-DD.') + self._max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info('Using precomputed release dates %s.', release_dates_path) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + def get_templates( + self, + query_sequence, + query_pdb_code, + query_release_date, + hhr_hits): + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_pdb_code) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + # Always use a max_template_date. Set to query_release_date minus 60 days + # if that's earlier. + template_cutoff_date = self._max_template_date + if query_release_date: + delta = datetime.timedelta(days=60) + if query_release_date - delta < template_cutoff_date: + template_cutoff_date = query_release_date - delta + assert template_cutoff_date < query_release_date + assert template_cutoff_date <= self._max_template_date + + num_hits = 0 + errors = [] + warnings = [] + + for hit in sorted(hhr_hits, key=lambda x: x.sum_probs, reverse=True): + # We got all the templates we wanted, stop processing HHSearch + # hits. + if num_hits >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=template_cutoff_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info('Skipped invalid hit %s, error: %s, warning: %s', hit.name, result.error, result.warning) + else: + # Increment the hit counter, since we got features out of this + # hit. 
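A small datetime sketch of the cutoff rule implemented in get_templates above: the effective cutoff is max_template_date, tightened to query_release_date minus 60 days whenever the query has a release date. The dates below are illustrative only:

```python
import datetime

max_template_date = datetime.datetime(2021, 9, 30)
query_release_date = datetime.datetime(2021, 8, 1)

cutoff = max_template_date
delta = datetime.timedelta(days=60)
if query_release_date - delta < cutoff:
    cutoff = query_release_date - delta  # tighten the cutoff
print(cutoff.date())  # 2021-06-02
assert cutoff < query_release_date and cutoff <= max_template_date
```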
+                num_hits += 1
+                for k in template_features:
+                    template_features.get(k).append(result.features[k])
+
+        for name in template_features:
+            if num_hits > 0:
+                template_features[name] = np.stack(template_features.get(name),
+                                                   axis=0).astype(TEMPLATE_FEATURES.get(name))
+            else:
+                # Make sure the feature has correct dtype even if empty.
+                template_features[name] = np.array([], dtype=TEMPLATE_FEATURES.get(name))
+
+        return TemplateSearchResult(features=template_features, errors=errors, warnings=warnings)
diff --git a/tests/st/mindsponge/test_megafold/data/utils.py b/tests/st/mindsponge/test_megafold/data/utils.py
new file mode 100644
index 000000000..6e34e3e43
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/data/utils.py
@@ -0,0 +1,40 @@
+# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+utils module used for tmpdir generation.
+"""
+
+import contextlib
+import tempfile
+import shutil
+
+
+@contextlib.contextmanager
+def tmpdir_manager(base_dir: str):
+    """Context manager that creates a temporary directory and deletes it on exit.
+
+    For example:
+
+        with tmpdir_manager(base_dir='/tmp') as tmp_dir:
+            test_file = os.path.join(tmp_dir, 'input.fasta')
+            with open(test_file, "w") as f:
+                f.write("this is a test. \n")
+            print("exit")
+
+    This creates a temporary directory under base_dir; once the with block
+    finishes (here, right after "exit" is printed), the directory and all of
+    its contents are removed.
+    """
+    tmpdir = tempfile.mkdtemp(dir=base_dir)
+    try:
+        yield tmpdir
+    finally:
+        shutil.rmtree(tmpdir, ignore_errors=True)
diff --git a/tests/st/mindsponge/test_megafold/model/__init__.py b/tests/st/mindsponge/test_megafold/model/__init__.py
new file mode 100644
index 000000000..86ba22bf8
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/model/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
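A numpy sketch of the per-feature stacking at the end of get_templates above: with at least one hit, the per-hit lists are stacked along a new leading axis; with none, an empty array with the declared dtype keeps downstream consumers type-stable. Shapes here are illustrative only:

```python
import numpy as np

dtype = np.float32
per_hit = [np.zeros((4, 37)), np.zeros((4, 37))]   # e.g. two template atom masks
stacked = np.stack(per_hit, axis=0).astype(dtype)
print(stacked.shape, stacked.dtype)                # (2, 4, 37) float32

empty = np.array([], dtype=dtype)                  # the num_hits == 0 case
print(empty.shape, empty.dtype)                    # (0,) float32
```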
+# ============================================================================
+'''init'''
+from .fold import MegaFold, compute_confidence
diff --git a/tests/st/mindsponge/test_megafold/model/fold.py b/tests/st/mindsponge/test_megafold/model/fold.py
new file mode 100644
index 000000000..f11f18ce1
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/model/fold.py
@@ -0,0 +1,304 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""model"""
+import numpy as np
+from scipy.special import softmax
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore import Parameter
+import mindsponge.common.residue_constants as residue_constants
+from mindsponge.common.utils import dgram_from_positions, pseudo_beta_fn, atom37_to_torsion_angles
+from mindsponge.data.data_transform import get_chi_atom_pos_indices
+from mindsponge.cell.initializer import lecun_init
+from module.template_embedding import TemplateEmbedding
+from module.evoformer import Evoformer
+from module.structure import StructureModule
+from module.head import DistogramHead, ExperimentallyResolvedHead, MaskedMsaHead, \
+    PredictedLDDTHead, PredictedAlignedErrorHead
+
+
+def calculate_constant_array(seq_length):
+    '''Precomputes the constant index and mask arrays used for torsion angle computation.'''
+    chi_atom_indices = np.array(get_chi_atom_pos_indices()).astype(np.int32)
+    chi_angles_mask = list(residue_constants.chi_angles_mask)
+    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
+    chi_angles_mask = np.array(chi_angles_mask).astype(np.float32)
+    mirror_psi_mask = np.float32(np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None])
+    chi_pi_periodic = np.float32(np.array(residue_constants.chi_pi_periodic))
+
+    indices0 = np.arange(4).reshape((-1, 1, 1, 1, 1)).astype("int32")  # 4 batch
+    indices0 = indices0.repeat(seq_length, axis=1)  # seq_length sequence length
+    indices0 = indices0.repeat(4, axis=2)  # 4 chis
+    indices0 = indices0.repeat(4, axis=3)  # 4 atoms
+
+    indices1 = np.arange(seq_length).reshape((1, -1, 1, 1, 1)).astype("int32")
+    indices1 = indices1.repeat(4, axis=0)
+    indices1 = indices1.repeat(4, axis=2)
+    indices1 = indices1.repeat(4, axis=3)
+
+    constant_array = [chi_atom_indices, chi_angles_mask, mirror_psi_mask, chi_pi_periodic, indices0, indices1]
+    constant_array = [Tensor(val) for val in constant_array]
+    return constant_array
+
+
+def compute_confidence(predicted_lddt_logits):
+    """Computes the mean pLDDT confidence from predicted LDDT logits."""
+
+    num_bins = predicted_lddt_logits.shape[-1]
+    bin_width = 1 / num_bins
+    start_n = bin_width / 2
+    plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width)
+    confidence = np.mean(plddt)
+    return confidence
+
+
+def compute_plddt(logits, start_n, bin_width):
+    """Computes per-residue pLDDT from logits.
+
+    Args:
+      logits: [num_res, num_bins] output from the PredictedLDDTHead.
+      start_n: Center of the first pLDDT bin.
+      bin_width: Width of each pLDDT bin.
+
+    Returns:
+      plddt: [num_res] per-residue pLDDT.
+    """
+    bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width)
+    probs = softmax(logits, axis=-1)
+    predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
+    return predicted_lddt_ca * 100
+
+
+class MegaFold(nn.Cell):
+    """MegaFold"""
+
+    def __init__(self, config, mixed_precision=True):
+        super(MegaFold, self).__init__()
+
+        self.cfg = config
+
+        if mixed_precision:
+            self._type = mstype.float16
+        else:
+            self._type = mstype.float32
+        self.is_training = self.cfg.is_training
+        self.recycle_pos = self.cfg.recycle_pos
+        self.recycle_features = self.cfg.recycle_features
+        self.max_relative_feature = self.cfg.max_relative_feature
+        self.num_bins = self.cfg.prev_pos.num_bins
+        self.min_bin = self.cfg.prev_pos.min_bin
+        self.max_bin = self.cfg.prev_pos.max_bin
+        self.template_enabled = self.cfg.template.enabled
+        self.template_embed_torsion_angles = self.cfg.template.embed_torsion_angles
+        self.extra_msa_stack_num = self.cfg.evoformer.extra_msa_stack_num
+        self.msa_stack_num = self.cfg.evoformer.msa_stack_num
+        self.chi_atom_indices, self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, \
+            self.indices0, self.indices1 = calculate_constant_array(self.cfg.seq_length)
+
+        self.preprocess_1d = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.msa_channel,
+                                      weight_init=lecun_init(self.cfg.common.target_feat_dim))
+        self.preprocess_msa = nn.Dense(self.cfg.common.msa_feat_dim, self.cfg.msa_channel,
+                                       weight_init=lecun_init(self.cfg.common.msa_feat_dim))
+        self.left_single = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.pair_channel,
+                                    weight_init=lecun_init(self.cfg.common.target_feat_dim))
+        self.right_single = nn.Dense(self.cfg.common.target_feat_dim, self.cfg.pair_channel,
+                                     weight_init=lecun_init(self.cfg.common.target_feat_dim))
+        self.prev_pos_linear = nn.Dense(self.cfg.common.dgram_dim, self.cfg.pair_channel,
+                                        weight_init=lecun_init(self.cfg.common.dgram_dim))
+        self.pair_activations = nn.Dense(self.cfg.common.pair_in_dim, self.cfg.pair_channel,
+                                         weight_init=lecun_init(self.cfg.common.pair_in_dim))
+        self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1)
+        self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1)
+        self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5)
+        self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5)
+        self.one_hot = nn.OneHot(depth=self.cfg.max_relative_feature * 2 + 1, axis=-1)
+        self.extra_msa_activations = nn.Dense(25, self.cfg.extra_msa_channel, weight_init=lecun_init(25))
+        self.template_embedding = TemplateEmbedding(self.cfg, self.cfg.seq_length, mixed_precision)
+
+        self.matmul_trans_b = P.MatMul(transpose_b=True)
+        self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+        self.template_single_embedding = nn.Dense(57, self.cfg.msa_channel,
+                                                  weight_init=lecun_init(57, initializer_name='relu'))
+        self.template_projection = nn.Dense(self.cfg.msa_channel, self.cfg.msa_channel,
+                                            weight_init=lecun_init(self.cfg.msa_channel,
+                                                                   initializer_name='relu'))
+        self.relu = nn.ReLU()
+        self.single_activations = nn.Dense(self.cfg.msa_channel, self.cfg.seq_channel,
+                                           weight_init=lecun_init(self.cfg.msa_channel))
+        extra_msa_stack = nn.CellList()
+        for _ in range(self.extra_msa_stack_num):
+            extra_msa_block = Evoformer(self.cfg,
+                                        msa_act_dim=64,
+                                        pair_act_dim=128,
+                                        is_extra_msa=True,
+                                        batch_size=None,
+                                        mixed_precision=mixed_precision)
+            if self.is_training:
+                extra_msa_block.recompute()
+            extra_msa_stack.append(extra_msa_block)
+        self.extra_msa_stack = 
extra_msa_stack + if self.is_training: + msa_stack = nn.CellList() + for _ in range(self.msa_stack_num): + msa_block = Evoformer(self.cfg, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=None, + mixed_precision=mixed_precision) + msa_block.recompute() + msa_stack.append(msa_block) + self.msa_stack = msa_stack + + self.module_distogram = DistogramHead(self.cfg.heads.distogram, + self.cfg.pair_channel) + self.module_exp_resolved = ExperimentallyResolvedHead(self.cfg.seq_channel) + self.module_mask = MaskedMsaHead(self.cfg.heads.masked_msa, + self.cfg.msa_channel) + self.aligned_error = PredictedAlignedErrorHead(self.cfg.heads.predicted_aligned_error, + self.cfg.pair_channel) + else: + self.msa_stack = Evoformer(self.cfg, + msa_act_dim=256, + pair_act_dim=128, + is_extra_msa=False, + batch_size=self.msa_stack_num, + mixed_precision=mixed_precision) + self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) + self.evoformer_num_block_eval = Tensor(self.msa_stack_num, mstype.int32) + + self.structure_module = StructureModule(self.cfg, + self.cfg.seq_channel, + self.cfg.pair_channel, + mixed_precision) + + self.module_lddt = PredictedLDDTHead(self.cfg.heads.predicted_lddt, + self.cfg.seq_channel) + + def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, + template_aatype, template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion, + extra_deletion_value, extra_msa_mask, + residx_atom37_to_atom14, atom37_atom_exists, residue_index, + prev_pos, prev_msa_first_row, prev_pair): + """construct""" + + preprocess_1d = self.preprocess_1d(target_feat) + preprocess_msa = self.preprocess_msa(msa_feat) + msa_activations = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + left_single = self.left_single(target_feat) + right_single = self.right_single(target_feat) + pair_activations = P.ExpandDims()(left_single, 1) + P.ExpandDims()(right_single, 0) + mask_2d = P.ExpandDims()(seq_mask, 1) * P.ExpandDims()(seq_mask, 0) + if self.recycle_pos: + prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None) + dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin, self._type) + pair_activations += self.prev_pos_linear(dgram) + + if self.recycle_features: + prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row) + msa_activations = mnp.concatenate( + (mnp.expand_dims(prev_msa_first_row + msa_activations[0, ...], 0), msa_activations[1:, ...]), 0) + pair_activations += self.prev_pair_norm(prev_pair) + + if self.max_relative_feature: + offset = P.ExpandDims()(residue_index, 1) - P.ExpandDims()(residue_index, 0) + rel_pos = self.one_hot(mnp.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature)) + pair_activations += self.pair_activations(rel_pos) + + template_pair_representation = 0 + if self.template_enabled: + template_pair_representation = self.template_embedding(pair_activations, template_aatype, + template_all_atom_masks, template_all_atom_positions, + template_mask, template_pseudo_beta_mask, + template_pseudo_beta, mask_2d) + pair_activations += template_pair_representation + msa_1hot = self.extra_msa_one_hot(extra_msa) + extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_deletion_value[..., None]), + axis=-1) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + extra_msa_mask_tmp = P.Transpose()(P.ExpandDims()(extra_msa_mask, -1), (2, 
1, 0)) + extra_msa_norm = P.Transpose()(self.batch_matmul_trans_b(extra_msa_mask_tmp, extra_msa_mask_tmp), (1, 2, 0)) + for i in range(self.extra_msa_stack_num): + extra_msa_activations, pair_activations = \ + self.extra_msa_stack[i](extra_msa_activations, pair_activations, extra_msa_mask, extra_msa_norm, + mask_2d) + + if self.template_enabled and self.template_embed_torsion_angles: + num_templ, num_res = template_aatype.shape + aatype_one_hot = self.template_aatype_one_hot(template_aatype) + torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask = atom37_to_torsion_angles( + template_aatype, template_all_atom_positions, template_all_atom_masks, self.chi_atom_indices, + self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, self.indices0, self.indices1) + template_features = mnp.concatenate([aatype_one_hot, + mnp.reshape(torsion_angles_sin_cos, [num_templ, num_res, 14]), + mnp.reshape(alt_torsion_angles_sin_cos, [num_templ, num_res, 14]), + torsion_angles_mask], axis=-1) + template_activations = self.template_single_embedding(template_features) + template_activations = self.relu(template_activations) + template_activations = self.template_projection(template_activations) + msa_activations = mnp.concatenate([msa_activations, template_activations], axis=0) + torsion_angle_mask = torsion_angles_mask[:, :, 2] + msa_mask = mnp.concatenate([msa_mask, torsion_angle_mask], axis=0) + + msa_mask_tmp = P.Transpose()(P.ExpandDims()(msa_mask, -1), (2, 1, 0)) + msa_mask_norm = P.Transpose()(self.batch_matmul_trans_b(msa_mask_tmp, msa_mask_tmp), (1, 2, 0)) + if self.is_training: + for i in range(self.msa_stack_num): + msa_activations, pair_activations = self.msa_stack[i](msa_activations, pair_activations, msa_mask, + msa_mask_norm, mask_2d) + else: + self.idx_evoformer_block = self.idx_evoformer_block * 0 + while self.idx_evoformer_block < self.evoformer_num_block_eval: + msa_activations, pair_activations = self.msa_stack(msa_activations, + pair_activations, + msa_mask, + msa_mask_norm, + mask_2d, + self.idx_evoformer_block) + self.idx_evoformer_block += 1 + single_activations = self.single_activations(msa_activations[0]) + num_sequences = msa_feat.shape[0] + msa = msa_activations[:num_sequences, :, :] + msa_first_row = msa_activations[0] + + final_atom_positions, _, rp_structure_module, atom14_pred_positions, final_affines, \ + angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj = \ + self.structure_module(single_activations, + pair_activations, + seq_mask, + aatype, + residx_atom37_to_atom14, + atom37_atom_exists) + predicted_lddt_logits = self.module_lddt(rp_structure_module) + if self.train_backward: + predicted_lddt_logits = self.module_lddt(rp_structure_module) + dist_logits, bin_edges = self.module_distogram(pair_activations) + experimentally_logits = self.module_exp_resolved(single_activations) + masked_logits = self.module_mask(msa) + aligned_error_logits, aligned_error_breaks = self.aligned_error(pair_activations) + return dist_logits, bin_edges, experimentally_logits, masked_logits, aligned_error_logits, \ + aligned_error_breaks, atom14_pred_positions, final_affines, angles_sin_cos_new, \ + predicted_lddt_logits, structure_traj, sidechain_frames, sidechain_atom_pos, \ + um_angles_sin_cos_new, final_atom_positions + final_atom_positions = P.Cast()(final_atom_positions, self._type) + prev_pos = final_atom_positions + prev_msa_first_row = msa_first_row + prev_pair = pair_activations + if self.is_training: + return prev_pos, 
prev_msa_first_row, prev_pair + return prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits diff --git a/tests/st/mindsponge/test_megafold/module/evoformer.py b/tests/st/mindsponge/test_megafold/module/evoformer.py new file mode 100644 index 000000000..44e838cb9 --- /dev/null +++ b/tests/st/mindsponge/test_megafold/module/evoformer.py @@ -0,0 +1,131 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Evoformer""" + +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindsponge.cell import MSARowAttentionWithPairBias, Transition, OuterProductMean, \ + TriangleAttention, TriangleMultiplication, \ + MSAColumnGlobalAttention, MSAColumnAttention + + +class Evoformer(nn.Cell): + '''evoformer''' + + def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size, mixed_precision): + super(Evoformer, self).__init__() + if is_extra_msa: + self.slice_cfg = config.slice.extra_msa_stack + else: + self.slice_cfg = config.slice.msa_stack + self.config = config.evoformer + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + self.config.msa_row_attention_with_pair_bias.num_head, + msa_act_dim, + msa_act_dim, + self.config.msa_row_attention_with_pair_bias.gating, + msa_act_dim, + pair_act_dim, + batch_size, + self.slice_cfg.msa_row_attention_with_pair_bias, + mixed_precision) + + self.msa_transition = Transition(self.config.msa_transition.num_intermediate_factor, + msa_act_dim, + batch_size, + self.slice_cfg.msa_transition, + mixed_precision) + + self.outer_product_mean = OuterProductMean(self.config.outer_product_mean.num_outer_channel, + msa_act_dim, + pair_act_dim, + batch_size, + self.slice_cfg.outer_product_mean, + mixed_precision) + + self.triangle_attention_starting_node = TriangleAttention( + self.config.triangle_attention_starting_node.orientation, + self.config.triangle_attention_starting_node.num_head, + pair_act_dim, + pair_act_dim, + self.config.triangle_attention_starting_node.gating, + pair_act_dim, + batch_size, + self.slice_cfg.triangle_attention_starting_node, + mixed_precision) + + self.triangle_attention_ending_node = TriangleAttention(self.config.triangle_attention_ending_node.orientation, + self.config.triangle_attention_ending_node.num_head, + pair_act_dim, + pair_act_dim, + self.config.triangle_attention_ending_node.gating, + pair_act_dim, + batch_size, + self.slice_cfg.triangle_attention_ending_node, + mixed_precision) + + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + pair_act_dim, + batch_size, + self.slice_cfg.pair_transition, + mixed_precision) + + self.triangle_multiplication_outgoing = TriangleMultiplication( + self.config.triangle_multiplication_outgoing.num_intermediate_channel, + 
self.config.triangle_multiplication_outgoing.equation,
+            layer_norm_dim=pair_act_dim,
+            batch_size=batch_size,
+            mixed_precision=mixed_precision)
+
+        self.triangle_multiplication_incoming = TriangleMultiplication(
+            self.config.triangle_multiplication_incoming.num_intermediate_channel,
+            self.config.triangle_multiplication_incoming.equation,
+            layer_norm_dim=pair_act_dim,
+            batch_size=batch_size,
+            mixed_precision=mixed_precision)
+        if is_extra_msa:
+            self.attn_mod = MSAColumnGlobalAttention(self.config.msa_column_attention.num_head,
+                                                     self.config.msa_column_attention.gating,
+                                                     msa_act_dim,
+                                                     batch_size,
+                                                     self.slice_cfg.msa_column_global_attention,
+                                                     mixed_precision)
+        else:
+            self.attn_mod = MSAColumnAttention(self.config.msa_column_attention.num_head,
+                                               msa_act_dim,
+                                               msa_act_dim,
+                                               self.config.msa_column_attention.gating,
+                                               msa_act_dim,
+                                               batch_size,
+                                               self.slice_cfg.msa_column_attention,
+                                               mixed_precision)
+
+    def construct(self, msa_act, pair_act, msa_mask, extra_msa_norm, pair_mask, index=None):
+        '''construct'''
+        # MSA track: pair-biased row attention, column (global) attention, transition.
+        msa_act = P.Add()(msa_act, self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, index))
+        msa_act = P.Add()(msa_act, self.attn_mod(msa_act, msa_mask, index))
+        msa_act = P.Add()(msa_act, self.msa_transition(msa_act, index))
+        # Pair track: the outer product mean feeds MSA information into the pair
+        # representation, followed by triangle updates and a transition.
+        pair_act = P.Add()(pair_act, self.outer_product_mean(msa_act, msa_mask, extra_msa_norm, index))
+        pair_act = P.Add()(pair_act, self.triangle_multiplication_outgoing(pair_act, pair_mask, index))
+        pair_act = P.Add()(pair_act, self.triangle_multiplication_incoming(pair_act, pair_mask, index))
+        pair_act = P.Add()(pair_act, self.triangle_attention_starting_node(pair_act, pair_mask, index))
+        pair_act = P.Add()(pair_act, self.triangle_attention_ending_node(pair_act, pair_mask, index))
+        pair_act = P.Add()(pair_act, self.pair_transition(pair_act, index))
+        return msa_act, pair_act
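A note on the wiring above: every sub-module output is added back onto its input, so each Evoformer block is a pure residual update on the two activation streams and can be stacked (or looped, as in the eval path of fold.py) without extra glue. A minimal sketch of that pattern, with toy NumPy callables standing in for the real attention cells (everything here is illustrative, not part of the patch):

import numpy as np

def residual_block(msa_act, pair_act, msa_updates, pair_updates):
    # msa_updates / pair_updates: lists of callables mapping the current
    # activations to a same-shaped update (stand-ins for row/column attention,
    # transitions, outer product mean and triangle updates).
    for update in msa_updates:
        msa_act = msa_act + update(msa_act, pair_act)
    for update in pair_updates:
        pair_act = pair_act + update(msa_act, pair_act)
    return msa_act, pair_act

# Toy usage on [N_seq, N_res, c_m] / [N_res, N_res, c_z] shaped streams.
msa, pair = np.zeros((8, 16, 4)), np.zeros((16, 16, 4))
msa, pair = residual_block(msa, pair, [lambda m, p: 0.1 * m], [lambda m, p: 0.1 * p])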
diff --git a/tests/st/mindsponge/test_megafold/module/fold_wrapcell.py b/tests/st/mindsponge/test_megafold/module/fold_wrapcell.py
new file mode 100644
index 000000000..cdab2fae2
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/module/fold_wrapcell.py
@@ -0,0 +1,160 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""wrap cell"""
+
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import ops
+from mindspore.context import ParallelMode
+from mindspore.nn import DistributedGradReducer
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.parallel._utils import _get_device_num
+from mindspore.parallel._utils import (_get_gradients_mean, _get_parallel_mode)
+from module.loss_module import LossNet
+
+GRADIENT_CLIP_TYPE = 1
+
+clip_grad = ops.MultitypeFuncGraph("clip_grad")
+
+
+@clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
+    """Clip a gradient by value (clip_type 0) or by norm (clip_type 1)."""
+    if clip_type not in (0, 1):
+        return grad
+    dt = ops.dtype(grad)
+    if clip_type == 0:
+        new_grad = ops.clip_by_value(grad, ops.cast(ops.tuple_to_array((-clip_value,)), dt),
+                                     ops.cast(ops.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt))
+    return new_grad
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+
+
+@grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    """Scale a gradient by the reciprocal of the loss-scaling factor."""
+    return grad * ops.Reciprocal()(scale)
+
+
+class TrainOneStepCell(nn.Cell):
+    """TrainOneStepCell"""
+    def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True, use_global_norm=True,
+                 gradient_clip_value=1.0):
+        super(TrainOneStepCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.optimizer = optimizer
+        self.weights = self.optimizer.parameters
+        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
+        self.sens = sens
+        self.enable_clip_grad = enable_clip_grad
+        self.hyper_map = ops.HyperMap()
+        self.use_global_norm = use_global_norm
+        self.gradient_clip_value = gradient_clip_value
+        # Toggled externally before compilation: True runs the backward pass,
+        # False makes this cell a plain forward wrapper.
+        self.train_backward = False
+
+        self.reducer_flag = False
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
+        if self.reducer_flag:
+            self.mean = _get_gradients_mean()
+            self.degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+
+    def construct(self, *inputs):
+        """construct"""
+        if self.train_backward:
+            loss = self.network(*inputs)
+            loss, l_fape_side, l_fape_backbone, l_anglenorm, distogram_loss, masked_loss, predict_lddt_loss = loss
+            # Only the total loss gets a non-zero sensitivity; the remaining
+            # outputs are monitored values and must not produce gradients.
+            sens = F.fill(loss.dtype, loss.shape, self.sens)
+            sens1 = F.fill(l_fape_side.dtype, l_fape_side.shape, 0.0)
+            sens2 = F.fill(l_fape_backbone.dtype, l_fape_backbone.shape, 0.0)
+            sens3 = F.fill(l_anglenorm.dtype, l_anglenorm.shape, 0.0)
+            sens4 = F.fill(distogram_loss.dtype, distogram_loss.shape, 0.0)
+            sens5 = F.fill(masked_loss.dtype, masked_loss.shape, 0.0)
+            sens6 = F.fill(predict_lddt_loss.dtype, predict_lddt_loss.shape, 0.0)
+
+            grads = self.grad(self.network, self.weights)(*inputs, (sens, sens1, sens2, sens3, sens4, sens5, sens6))
+            grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads)
+            grads = self.grad_reducer(grads)
+            if self.enable_clip_grad:
+                if self.use_global_norm:
+                    grads = C.clip_by_global_norm(grads, self.gradient_clip_value)
+                else:
+                    grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, self.gradient_clip_value), grads)
+
+            loss = F.depend(loss, self.optimizer(grads))
+            return loss, l_fape_side, l_fape_backbone, l_anglenorm, distogram_loss, masked_loss, predict_lddt_loss
+
+        out = self.network(*inputs)
+        return out
+
+
+class WithLossCell(nn.Cell):
+    """WithLossCell"""
+    def __init__(self, backbone, config):
+        super(WithLossCell, self).__init__(auto_prefix=False)
+        self._backbone = backbone
+        self.loss_net = LossNet(config).to_float(mstype.float32)
+        # Toggled externally, mirroring TrainOneStepCell.train_backward.
+        self.train_backward = False
+
+    def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype,
+                  template_aatype, template_all_atom_masks, template_all_atom_positions,
+                  template_mask, template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion,
+                  extra_deletion_value, extra_msa_mask,
+                  residx_atom37_to_atom14, atom37_atom_exists, residue_index,
+                  prev_pos, prev_msa_first_row, prev_pair, pseudo_beta_gt, pseudo_beta_mask_gt,
+                  all_atom_mask_gt, true_msa, bert_mask,
+                  residx_atom14_to_atom37, restype_atom14_bond_lower_bound, restype_atom14_bond_upper_bound,
+                  atomtype_radius, backbone_affine_tensor, backbone_affine_mask,
+                  atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists,
+                  atom14_atom_exists, atom14_alt_gt_exists, all_atom_positions, rigidgroups_gt_frames,
+                  rigidgroups_gt_exists, rigidgroups_alt_gt_frames, torsion_angles_sin_cos_gt, use_clamped_fape,
+                  filter_by_solution, chi_mask):
+        """construct"""
+        if self.train_backward:
+            dist_logits, bin_edges, experimentally_logits, masked_logits, aligned_error_logits, aligned_error_breaks, \
+                atom14_pred_positions, final_affines, angles_sin_cos_new, predicted_lddt_logits, structure_traj, \
+                sidechain_frames, sidechain_atom_pos, um_angles_sin_cos_new, final_atom_positions = \
+                self._backbone(target_feat, msa_feat, msa_mask, seq_mask, aatype, template_aatype,
+                               template_all_atom_masks, template_all_atom_positions, template_mask,
+                               template_pseudo_beta_mask, template_pseudo_beta, extra_msa, extra_has_deletion,
+                               extra_deletion_value, extra_msa_mask, residx_atom37_to_atom14, atom37_atom_exists,
+                               residue_index, prev_pos, prev_msa_first_row, prev_pair)
+            out = self.loss_net(dist_logits, bin_edges, pseudo_beta_gt, pseudo_beta_mask_gt,
+                                experimentally_logits, atom37_atom_exists, all_atom_mask_gt, true_msa,
+                                masked_logits, bert_mask, atom14_pred_positions, residue_index, aatype,
+                                residx_atom14_to_atom37, restype_atom14_bond_lower_bound,
+                                restype_atom14_bond_upper_bound, seq_mask, atomtype_radius, final_affines,
+                                aligned_error_breaks, aligned_error_logits, angles_sin_cos_new,
+                                um_angles_sin_cos_new, backbone_affine_tensor, backbone_affine_mask,
+                                atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous,
+                                atom14_gt_exists, atom14_atom_exists, atom14_alt_gt_exists,
+                                final_atom_positions, all_atom_positions, predicted_lddt_logits,
+                                structure_traj, rigidgroups_gt_frames, rigidgroups_gt_exists,
+                                rigidgroups_alt_gt_frames,
+                                sidechain_frames, sidechain_atom_pos, torsion_angles_sin_cos_gt,
+                                chi_mask, use_clamped_fape, filter_by_solution)
+        else:
+            out = self._backbone(target_feat, msa_feat, msa_mask, seq_mask, aatype, template_aatype,
+                                 template_all_atom_masks, template_all_atom_positions, template_mask,
+                                 template_pseudo_beta_mask, template_pseudo_beta, extra_msa,
+                                 extra_has_deletion, extra_deletion_value, extra_msa_mask, residx_atom37_to_atom14,
+                                 atom37_atom_exists, residue_index, prev_pos, prev_msa_first_row, prev_pair)
+        return out
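The zero-filled sens tensors in TrainOneStepCell.construct are the standard MindSpore way to backpropagate through only one of several network outputs: GradOperation(sens_param=True) takes one sensitivity per output, and a zero sensitivity removes that output from the backward pass, so the six auxiliary losses come back purely for logging. A self-contained sketch of the mechanism (toy network and shapes are hypothetical):

import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore import Tensor, ParameterTuple

class TwoOutputNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 1)

    def construct(self, x):
        loss = self.dense(x).sum()
        metric = 10.0 * loss      # monitoring output, must stay gradient-free
        return loss, metric

net = TwoOutputNet()
weights = ParameterTuple(net.trainable_params())
grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)
x = Tensor([[1.0, 2.0, 3.0, 4.0]], mstype.float32)
# One sensitivity per output: 1 for the loss, 0 for the metric.
sens = (Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32))
grads = grad_fn(net, weights)(x, sens)   # identical to gradients of the loss alone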
diff --git a/tests/st/mindsponge/test_megafold/module/head.py b/tests/st/mindsponge/test_megafold/module/head.py
new file mode 100644
index 000000000..6fa353748
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/module/head.py
@@ -0,0 +1,166 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""prediction heads"""
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.numpy as mnp
+from mindsponge.cell.initializer import lecun_init
+
+
+class PredictedLDDTHead(nn.Cell):
+    """Head to predict the per-residue LDDT to be used as a confidence measure."""
+
+    def __init__(self, config, seq_channel):
+        super().__init__()
+        self.config = config
+        self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
+        self.act_0 = nn.Dense(seq_channel, self.config.num_channels,
+                              weight_init=lecun_init(seq_channel, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels,
+                              weight_init=lecun_init(self.config.num_channels, initializer_name='relu')
+                              ).to_float(mstype.float16)
+        self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros'
+                               ).to_float(mstype.float16)
+        self.relu = nn.ReLU()
+
+    def construct(self, rp_structure_module):
+        """Builds PredictedLDDTHead module."""
+        act = rp_structure_module
+        act = self.input_layer_norm(act.astype(mstype.float32))
+        act = self.act_0(act)
+        act = self.relu(act.astype(mstype.float32))
+        act = self.act_1(act)
+        act = self.relu(act.astype(mstype.float32))
+        logits = self.logits(act)
+        return logits
+
+
+class DistogramHead(nn.Cell):
+    """Head to predict a distogram.
+
+    Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction"
+    """
+
+    def __init__(self, config, pair_dim):
+        super().__init__()
+        self.config = config
+        self.half_logits = nn.Dense(pair_dim, self.config.num_bins, weight_init='zeros')
+        self.first_break = self.config.first_break
+        self.last_break = self.config.last_break
+        self.num_bins = self.config.num_bins
+
+    def construct(self, pair):
+        """Builds DistogramHead module.
+
+        Arguments:
+            pair: pair representation, shape [N_res, N_res, c_z].
+
+        Returns:
+            logits: logits for distogram, shape [N_res, N_res, N_bins].
+            breaks: array containing bin breaks, shape [N_bins - 1,].
+        """
+        half_logits = self.half_logits(pair)
+
+        # symmetrize so that d(i, j) and d(j, i) share one prediction
+        logits = half_logits + mnp.swapaxes(half_logits, -2, -3)
+        breaks = mnp.linspace(self.first_break, self.last_break, self.num_bins - 1)
+
+        return logits, breaks
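The symmetrized logits and linspace breaks are all a consumer needs to post-process the distogram. For instance, contact probabilities fall out by summing the probability mass of the bins below a distance cutoff; a NumPy sketch (the 8 Å cutoff is a common convention, not something this file defines):

import numpy as np

def contact_probability(logits, breaks, cutoff=8.0):
    # logits: [N_res, N_res, num_bins]; breaks: the num_bins - 1 bin edges.
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)          # softmax over bins
    # bins whose upper edge lies below the cutoff; the final open-ended bin
    # (everything beyond the last edge) never counts as a contact
    within = np.concatenate([breaks < cutoff, [False]])
    return probs[..., within].sum(axis=-1)              # [N_res, N_res]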
+
+
+class ExperimentallyResolvedHead(nn.Cell):
+    """Predicts if an atom is experimentally resolved in a high-res structure.
+
+    Only trained on high-resolution X-ray crystals & cryo-EM.
+    Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction'
+    """
+
+    def __init__(self, seq_channel):
+        super().__init__()
+        self.logits = nn.Dense(seq_channel, 37, weight_init='zeros')
+
+    def construct(self, single):
+        """Builds ExperimentallyResolvedHead module.
+
+        Arguments:
+            single: the single representation, shape [N_res, c_s].
+
+        Returns:
+            logits: logits of shape [N_res, 37],
+                log probability that an atom is resolved in atom37 representation,
+                can be converted to probability by applying sigmoid.
+        """
+        logits = self.logits(single)
+        return logits
+
+
+class MaskedMsaHead(nn.Cell):
+    """Head to predict MSA at the masked locations.
+
+    The MaskedMsaHead employs a BERT-style objective to reconstruct a masked
+    version of the full MSA, based on a linear projection of
+    the MSA representation.
+    Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction"
+    """
+
+    def __init__(self, config, msa_channel):
+        super().__init__()
+        self.config = config
+        self.logits = nn.Dense(msa_channel, self.config.num_output, weight_init='zeros')
+
+    def construct(self, msa):
+        """Builds MaskedMsaHead module.
+
+        Arguments:
+            msa: MSA representation, shape [N_seq, N_res, c_m].
+
+        Returns:
+            logits: logits of shape [N_seq, N_res, N_aatype] with
+                (unnormalized) log probabilities of the predicted aatype at each position.
+        """
+        logits = self.logits(msa)
+        return logits
+
+
+class PredictedAlignedErrorHead(nn.Cell):
+    """Head to predict the distance errors in the backbone alignment frames.
+
+    Can be used to compute predicted TM-Score.
+    Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction"
+    """
+
+    def __init__(self, config, pair_dim):
+        super().__init__()
+        self.config = config
+        self.num_bins = self.config.num_bins
+        self.max_error_bin = self.config.max_error_bin
+        self.logits = nn.Dense(pair_dim, self.num_bins, weight_init='zeros')
+
+    def construct(self, pair):
+        """Builds PredictedAlignedErrorHead module.
+
+        Arguments:
+            pair: pair representation, shape [N_res, N_res, c_z].
+
+        Returns:
+            logits: logits for aligned error, shape [N_res, N_res, N_bins].
+            breaks: array containing bin breaks, shape [N_bins - 1].
+        """
+        logits = self.logits(pair)
+        breaks = mnp.linspace(0, self.max_error_bin, self.num_bins - 1)
+        return logits, breaks
diff --git a/tests/st/mindsponge/test_megafold/module/loss_module.py b/tests/st/mindsponge/test_megafold/module/loss_module.py
new file mode 100644
index 000000000..2e83952b8
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/module/loss_module.py
@@ -0,0 +1,296 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""loss module""" + +import mindspore as ms +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore import Tensor +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindsponge.common import residue_constants +from mindsponge.common.geometry import invert_point, quaternion_from_tensor, vecs_expend_dims +from mindsponge.metrics.structure_violations import softmax_cross_entropy, sigmoid_cross_entropy, \ + get_structural_violations, compute_renamed_ground_truth, backbone, sidechain, supervised_chi, \ + local_distance_difference_test + + +class LossNet(nn.Cell): + """loss net""" + def __init__(self, config): + super(LossNet, self).__init__() + self.config = config + self.num_res = config.seq_length + self.num_bins = config.heads.distogram.num_bins + self.resolution = config.heads.resolution + self.distogram_weight = config.heads.distogram.weight + self.distogram_one_hot = nn.OneHot(depth=self.num_bins, axis=-1) + self.exp_min_resolution = config.heads.experimentally_resolved.min_resolution + self.exp_max_resolution = config.heads.experimentally_resolved.max_resolution + self.exp_res_filter_by_resolution = config.heads.experimentally_resolved.filter_by_resolution + self.experimentally_weight = config.heads.experimentally_resolved.weight + self.exp_res_mask = Tensor(1, ms.float32)\ + if not self.exp_res_filter_by_resolution or\ + (self.exp_min_resolution <= self.resolution <= self.exp_max_resolution) else Tensor(0, ms.float32) + + self.ael_min_resolution = config.heads.predicted_aligned_error.min_resolution + self.ael_max_resolution = config.heads.predicted_aligned_error.max_resolution + self.ael_res_filter_by_resolution = config.heads.predicted_aligned_error.filter_by_resolution + self.ael_res_mask = Tensor(1, ms.float32)\ + if not self.ael_res_filter_by_resolution or\ + (self.ael_min_resolution <= self.resolution <= self.ael_max_resolution) else Tensor(0, ms.float32) + self.aligned_one_hot = nn.OneHot(depth=config.heads.predicted_aligned_error.num_bins) + + self.plddt_min_resolution = config.heads.predicted_lddt.min_resolution + self.plddt_max_resolution = config.heads.predicted_lddt.max_resolution + self.plddt_res_filter_by_resolution = config.heads.predicted_lddt.filter_by_resolution + self.plddt_res_mask = Tensor(1, ms.float32)\ + if not self.plddt_res_filter_by_resolution or\ + (self.plddt_min_resolution <= self.resolution <= self.plddt_max_resolution) else Tensor(0, ms.float32) + self.plddt_weight = config.heads.predicted_lddt.weight + + self.masked_one_hot = nn.OneHot(depth=23, axis=-1) + self.masked_weight = config.heads.masked_msa.weight + self.sidechain_weight_frac = config.heads.structure_module.sidechain.weight_frac + self.angle_norm_weight = config.heads.structure_module.angle_norm_weight + self.chi_weight = config.heads.structure_module.chi_weight + self.chi_pi_periodic = mnp.asarray(residue_constants.chi_pi_periodic, ms.float32) + + self.violation_tolerance_factor = config.heads.structure_module.violation_tolerance_factor + self.clash_overlap_tolerance = config.heads.structure_module.clash_overlap_tolerance + self.sidechain_atom_clamp_distance = config.heads.structure_module.sidechain.atom_clamp_distance + self.sidechain_length_scale = config.heads.structure_module.sidechain.length_scale + self.fape_clamp_distance = config.heads.structure_module.fape.clamp_distance + self.fape_loss_unit_distance = 
config.heads.structure_module.fape.loss_unit_distance + self.predicted_lddt_num_bins = config.heads.predicted_lddt.num_bins + self.c_one_hot = nn.OneHot(depth=14) + self.n_one_hot = nn.OneHot(depth=14) + self.zeros = Tensor(0, ms.int32) + self.twos = Tensor(2, ms.int32) + self.dists_mask_i = mnp.eye(14, 14) + self.cys_sg_idx = Tensor(5, ms.int32) + + def distogram_loss(self, logits, bin_edges, pseudo_beta, pseudo_beta_mask): + """Log loss of a distogram.""" + positions = pseudo_beta + mask = pseudo_beta_mask + + sq_breaks = mnp.square(bin_edges) + dist_t = mnp.square(mnp.expand_dims(positions, axis=-2) - mnp.expand_dims(positions, axis=-3)) + dist2 = P.ReduceSum(True)(dist_t.astype(ms.float32), -1) + aa = (dist2 > sq_breaks).astype(ms.float32) + + true_bins = P.ReduceSum()(aa, -1) + true_bins = true_bins.astype(ms.int32) + errors = softmax_cross_entropy(labels=self.distogram_one_hot(true_bins), logits=logits) + square_mask = mnp.expand_dims(mask, axis=-2) * mnp.expand_dims(mask, axis=-1) + avg_error = (P.ReduceSum()(errors * square_mask, (-2, -1)) / + (1e-6 + P.ReduceSum()(square_mask.astype(ms.float32), (-2, -1)))) + + dist2 = dist2[..., 0] + loss = avg_error + true_dist = mnp.sqrt(1e-6 + dist2) + + return loss, true_dist + + def experimentally_loss(self, experimentally_logits, atom37_atom_exists, all_atom_mask, filter_by_solution): + """experimentally_loss""" + logits = experimentally_logits + + # Does the atom appear in the amino acid? + atom_exists = atom37_atom_exists + # Is the atom resolved in the experiment? Subset of atom_exists, + # *except for OXT* + all_atom_mask = all_atom_mask.astype(mnp.float32) + + xent = sigmoid_cross_entropy(logits, all_atom_mask) + loss = P.ReduceSum()(xent * atom_exists) / (1e-8 + P.ReduceSum()(atom_exists.astype(ms.float32))) + loss = loss * filter_by_solution + loss *= self.exp_res_mask + return loss + + def masked_head_loss(self, true_msa, logits, bert_mask): + """masked_head_loss""" + errors = softmax_cross_entropy(logits=logits, labels=self.masked_one_hot(true_msa)) + loss = (P.ReduceSum()(errors * bert_mask, (-2, -1)) / + (1e-8 + P.ReduceSum()(bert_mask.astype(ms.float32), (-2, -1)))) + return loss + + # todo + def structure_loss(self, atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, + atom14_gt_exists, atom14_atom_exists, final_atom14_positions, atom14_alt_gt_exists, + residue_index, aatype, residx_atom14_to_atom37, lower_bound, upper_bound, seq_mask, + atomtype_radius, angles_sin_cos, um_angles_sin_cos, traj, backbone_affine_tensor, + backbone_affine_mask, rigidgroups_gt_frames, rigidgroups_gt_exists, rigidgroups_alt_gt_frames, + pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, use_clamped_fape): + """structure_loss""" + atom14_pred_positions = final_atom14_positions + # Compute renaming and violations. 
+ alt_naming_is_better, renamed_atom14_gt_positions, renamed_atom14_gt_exists = \ + compute_renamed_ground_truth(atom14_gt_positions, + atom14_alt_gt_positions, + atom14_atom_is_ambiguous, + atom14_gt_exists, + atom14_pred_positions, + atom14_alt_gt_exists) + (bonds_c_n_loss_mean, angles_ca_c_n_loss_mean, angles_c_n_ca_loss_mean, _, + _, _, clashes_per_atom_loss_sum, _, per_atom_loss_sum, _, _, _) = \ + get_structural_violations(atom14_atom_exists, residue_index, aatype, residx_atom14_to_atom37, + atom14_pred_positions, self.violation_tolerance_factor, + self.clash_overlap_tolerance, lower_bound, upper_bound, atomtype_radius, + self.c_one_hot(self.twos), self.n_one_hot(self.zeros), self.dists_mask_i, + self.cys_sg_idx) + num_atoms = P.ReduceSum()(atom14_atom_exists.astype(ms.float32)) + structure_violation_loss = bonds_c_n_loss_mean + angles_ca_c_n_loss_mean + angles_c_n_ca_loss_mean + \ + P.ReduceSum()(clashes_per_atom_loss_sum + per_atom_loss_sum) / (1e-6 + num_atoms) + + # from structure module result + _, fape_loss, no_clamp = backbone(traj, backbone_affine_tensor, backbone_affine_mask, + self.fape_clamp_distance, self.fape_loss_unit_distance, use_clamped_fape) + + _, loss_sidechain = sidechain(alt_naming_is_better, rigidgroups_gt_frames, rigidgroups_alt_gt_frames, + rigidgroups_gt_exists, renamed_atom14_gt_positions, renamed_atom14_gt_exists, + self.sidechain_atom_clamp_distance, self.sidechain_length_scale, pred_frames, + pred_positions) + angle_norm_loss = supervised_chi(seq_mask, aatype, sin_cos_true_chi, torsion_angle_mask, + angles_sin_cos, um_angles_sin_cos, self.chi_weight, + self.angle_norm_weight, self.chi_pi_periodic) + return fape_loss, loss_sidechain, angle_norm_loss, structure_violation_loss, no_clamp + + def predicted_lddt_loss(self, final_atom_positions, all_atom_positions, all_atom_mask, predicted_lddt_logits, + filter_by_solution): + """predicted_lddt_loss""" + pred_all_atom_pos = final_atom_positions + true_all_atom_pos = all_atom_positions + lddt_ca = local_distance_difference_test( + predicted_points=pred_all_atom_pos[None, :, 1, :], + true_points=true_all_atom_pos[None, :, 1, :], + true_points_mask=all_atom_mask[None, :, 1:2].astype(mnp.float32), + cutoff=15., + per_residue=True)[0] + + lddt_ca = F.stop_gradient(lddt_ca) + + bin_index = mnp.floor(lddt_ca * self.predicted_lddt_num_bins).astype(mnp.int32) + + # protect against out of range for lddt_ca == 1 + bin_index = mnp.minimum(bin_index, self.predicted_lddt_num_bins - 1) + lddt_ca_one_hot = nn.OneHot(depth=self.predicted_lddt_num_bins)(bin_index) + + logits = predicted_lddt_logits + errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) + + mask_ca = all_atom_mask[:, 1] + mask_ca = mask_ca.astype(mnp.float32) + loss = P.ReduceSum()(errors * mask_ca) / P.ReduceSum()(P.ReduceSum()(mask_ca) + 1e-8) + loss = loss * filter_by_solution + loss *= self.plddt_res_mask + + return loss + + def aligned_error_loss(self, final_affines, backbone_affine_tensor, backbone_affine_mask, pae_breaks, pae_logits, + filter_by_solution): + """aligned_error_loss""" + # Shape (num_res, 7) predict affine + _, rotation_pd, translation_pd = quaternion_from_tensor(final_affines) + translation_point_pd = vecs_expend_dims(translation_pd, -2) + rotation_pd_tensor = rotation_pd + # Shape (num_res, 7) true affine + _, rotation_gt, translation_gt = quaternion_from_tensor(backbone_affine_tensor) + translation_point_tr = vecs_expend_dims(translation_gt, -2) + rotation_gt_tensor = rotation_gt + mask = backbone_affine_mask + square_mask 
= (mask[:, None] * mask[None, :]).astype(ms.float32) + breaks = pae_breaks + logits = pae_logits + + local_frames_pd = invert_point(translation_point_pd, rotation_pd_tensor, translation_pd, extra_dims=1) + local_frames_gt = invert_point(translation_point_tr, rotation_gt_tensor, translation_gt, extra_dims=1) + # todo to be checked + error_dist2 = mnp.square(local_frames_pd[0] - local_frames_gt[0]) + \ + mnp.square(local_frames_pd[1] - local_frames_gt[1]) + \ + mnp.square(local_frames_pd[2] - local_frames_gt[2]) + # # Compute the squared error for each alignment. + sq_breaks = mnp.square(breaks) + true_bins = P.ReduceSum()((error_dist2[..., None] > sq_breaks).astype(mnp.float32), -1) + + errors = softmax_cross_entropy(labels=self.aligned_one_hot(true_bins.astype(ms.int32)), logits=logits) + + loss = (P.ReduceSum()(errors * square_mask, (-2, -1)) / + (1e-8 + P.ReduceSum()(square_mask, (-2, -1)))) + loss = loss * filter_by_solution + loss *= self.ael_res_mask + + return loss + + def rmsd_loss(self, predicted_atom_positions, label_atom_positions, pseudo_beta_mask_2d): + """rmsd_loss""" + dist1 = P.Sqrt()((P.Square()(predicted_atom_positions[None] - + predicted_atom_positions[:, None])).sum(-1) + 1e-8) + dist2 = P.Sqrt()((P.Square()(label_atom_positions[None] - label_atom_positions[:, None])).sum(-1) + 1e-8) + return P.Sqrt()((P.Square()(dist1 - dist2) * pseudo_beta_mask_2d).mean() + 1e-8) + + def construct(self, distogram_logits, bin_edges, pseudo_beta, pseudo_beta_mask, experimentally_logits, + atom37_atom_exists, all_atom_mask, true_msa, masked_logits, bert_mask, + final_atom14_positions, residue_index, aatype, residx_atom14_to_atom37, lower_bound, upper_bound, + seq_mask, atomtype_radius, final_affines, pae_breaks, pae_logits, angles_sin_cos, + um_angles_sin_cos, backbone_affine_tensor, backbone_affine_mask, atom14_gt_positions, + atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists, atom14_atom_exists, + atom14_alt_gt_exists, final_atom_positions, all_atom_positions, predicted_lddt_logits, traj, + rigidgroups_gt_frames, rigidgroups_gt_exists, rigidgroups_alt_gt_frames, + pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, use_clamped_fape, + filter_by_solution): + """construct""" + distogram_loss, _ = self.distogram_loss(distogram_logits, bin_edges, pseudo_beta, pseudo_beta_mask) + distogram_loss = distogram_loss * self.distogram_weight + + masked_loss = self.masked_head_loss(true_msa, masked_logits, bert_mask) + masked_loss = self.masked_weight * masked_loss + + fape_loss, loss_sidechain, angle_norm_loss, _, _ = \ + self.structure_loss(atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, + atom14_gt_exists, atom14_atom_exists, final_atom14_positions, + atom14_alt_gt_exists, residue_index, aatype, residx_atom14_to_atom37, + lower_bound, upper_bound, seq_mask, atomtype_radius, angles_sin_cos, + um_angles_sin_cos, traj, backbone_affine_tensor, + backbone_affine_mask, rigidgroups_gt_frames, rigidgroups_gt_exists, + rigidgroups_alt_gt_frames, + pred_frames, pred_positions, sin_cos_true_chi, torsion_angle_mask, use_clamped_fape) + + self.experimentally_loss(experimentally_logits, atom37_atom_exists, all_atom_mask, filter_by_solution) + + predict_lddt_loss = self.predicted_lddt_loss(final_atom_positions, all_atom_positions, all_atom_mask, + predicted_lddt_logits, filter_by_solution) + predict_lddt_loss = self.plddt_weight * predict_lddt_loss + + self.aligned_error_loss(final_affines, backbone_affine_tensor, backbone_affine_mask, pae_breaks, + 
pae_logits, filter_by_solution) + # # todo check whether to use it + # aligned_error_loss = self.aligned_error_loss(final_affines, backbone_affine_tensor, + # backbone_affine_mask, pae_breaks, pae_logits, filter_by_solution) + + l_fape_side = 0.5 * loss_sidechain + l_fape_backbone = 0.5 * fape_loss + l_anglenorm = angle_norm_loss + + loss = l_fape_side + \ + l_fape_backbone + \ + l_anglenorm + \ + distogram_loss + \ + masked_loss + \ + predict_lddt_loss + + loss = loss * P.Sqrt()(P.ReduceSum()(all_atom_mask[:, 0])) + return loss, l_fape_side, l_fape_backbone, l_anglenorm, distogram_loss, masked_loss, predict_lddt_loss diff --git a/tests/st/mindsponge/test_megafold/module/structure.py b/tests/st/mindsponge/test_megafold/module/structure.py new file mode 100644 index 000000000..f6010a96f --- /dev/null +++ b/tests/st/mindsponge/test_megafold/module/structure.py @@ -0,0 +1,261 @@ +# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""structure module""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops import functional as F +from mindsponge.cell import InvariantPointAttention +import mindsponge.common.residue_constants as residue_constants +from mindsponge.cell.initializer import lecun_init +from mindsponge.common.utils import torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, \ + atom14_to_atom37 +from mindsponge.common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale,\ + vecs_to_tensor, vecs_expend_dims, rots_expend_dims + + +class MultiRigidSidechain(nn.Cell): + """Class to make side chain atoms.""" + + def __init__(self, config, single_repr_dim, mixed_precision): + super().__init__() + self.config = config + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.input_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.input_projection_1 = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.relu = nn.ReLU() + self.resblock1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, + initializer_name='relu')) + self.resblock2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.resblock1_1 = nn.Dense(self.config.num_channel, self.config.num_channel, + weight_init=lecun_init(self.config.num_channel, initializer_name='relu')) + self.resblock2_1 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros') + self.unnormalized_angles = nn.Dense(self.config.num_channel, 14, + weight_init=lecun_init(self.config.num_channel)) + self.restype_atom14_to_rigid_group = 
Tensor(residue_constants.restype_atom14_to_rigid_group)
+        self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
+        self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
+        self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)
+        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-12)
+
+    def construct(self, rotation, translation, act, initial_act, aatype):
+        """Predict side chains using rotation and translation representations.
+
+        Args:
+            rotation: The rotation matrices.
+            translation: The translation vectors.
+            act: updated single activations from the structure module
+            initial_act: initial act representations (input of structure module)
+            aatype: Amino acid type representations
+
+        Returns:
+            angles, positions and new frames
+        """
+
+        act1 = self.input_projection(self.relu(act))
+        init_act1 = self.input_projection_1(self.relu(initial_act))
+        # Sum the activation list (equivalent to concat then Linear).
+        act = act1 + init_act1
+
+        # Mapping with some residual blocks.
+        # resblock1
+        old_act = act
+        act = self.resblock1(self.relu(act))
+        act = self.resblock2(self.relu(act))
+        act += old_act
+        # resblock2
+        old_act = act
+        act = self.resblock1_1(self.relu(act))
+        act = self.resblock2_1(self.relu(act))
+        act += old_act
+
+        # Map activations to torsion angles. Shape: (num_res, 14).
+        num_res = act.shape[0]
+        unnormalized_angles = self.unnormalized_angles(self.relu(act))
+
+        unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])
+        angles = self.l2_normalize(unnormalized_angles)
+
+        backb_to_global = ((rotation[0], rotation[1], rotation[2],
+                            rotation[3], rotation[4], rotation[5],
+                            rotation[6], rotation[7], rotation[8]),
+                           (translation[0], translation[1], translation[2]))
+
+        all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles,
+                                                        self.restype_rigid_group_default_frame)
+
+        pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global,
+                                                                       self.restype_atom14_to_rigid_group,
+                                                                       self.restype_atom14_rigid_group_positions,
+                                                                       self.restype_atom14_mask)
+
+        atom_pos = pred_positions
+        frames = all_frames_to_global
+        res = (angles, unnormalized_angles, atom_pos, frames)
+        return res
+
+
+class FoldIteration(nn.Cell):
+    """A single iteration of the main structure module loop."""
+
+    def __init__(self, config, pair_dim, single_repr_dim, mixed_precision):
+        super().__init__()
+        self.config = config
+        if mixed_precision:
+            self._type = mstype.float16
+        else:
+            self._type = mstype.float32
+        self.drop_out = nn.Dropout(keep_prob=0.9)
+        self.attention_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
+        self.transition_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
+        self.transition = nn.Dense(self.config.num_channel, config.num_channel,
+                                   weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
+        self.transition_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
+                                     weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
+        self.transition_2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
+        self.relu = nn.ReLU()
+        self.affine_update = nn.Dense(self.config.num_channel, 6, weight_init='zeros')
+        self.attention_module = InvariantPointAttention(self.config.num_head,
+                                                        self.config.num_scalar_qk,
+                                                        self.config.num_scalar_v,
+                                                        self.config.num_point_v,
+                                                        self.config.num_point_qk,
+                                                        self.config.num_channel,
+                                                        pair_dim,
+ mixed_precision) + self.mu_side_chain = MultiRigidSidechain(self.config.sidechain, single_repr_dim, mixed_precision) + self.print = ops.Print() + + def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype): + """construct""" + attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation) + act += attn + act = self.drop_out(act) + act = self.attention_layer_norm(act) + # Transition + input_act = act + act = self.transition(act) + act = self.relu(act) + act = self.transition_1(act) + act = self.relu(act) + act = self.transition_2(act) + + act += input_act + act = self.drop_out(act) + act = self.transition_layer_norm(act) + + # This block corresponds to + # Jumper et al. (2021) Alg. 23 "Backbone update" + # Affine update + affine_update = self.affine_update(act) + quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update) + translation1 = vecs_scale(translation, 10.0) + rotation1 = rotation + angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \ + self.mu_side_chain(rotation1, translation1, act, initial_act, aatype) + + affine_output = quaternion_to_tensor(quaternion, translation) + quaternion = F.stop_gradient(quaternion) + rotation = F.stop_gradient(rotation) + res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ + atom_pos, frames) + return res + + +class StructureModule(nn.Cell): + """StructureModule as a network head.""" + + def __init__(self, config, single_repr_dim, pair_dim, mixed_precision): + super(StructureModule, self).__init__() + self.config = config.structure_module + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.seq_length = config.seq_length + self.fold_iteration = FoldIteration(self.config, pair_dim, single_repr_dim, mixed_precision) + self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5) + self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5) + self.num_layer = self.config.num_layer + self.indice0 = Tensor( + np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32")) + self.traj_w = Tensor(np.array([1.] 
* 4 + [self.config.position_scale] * 3), mstype.float32)
+
+    def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None):
+        """construct"""
+        sequence_mask = seq_mask[:, None]
+        act = self.single_layer_norm(single)
+        initial_act = act
+        act = self.initial_projection(act)
+        quaternion, rotation, translation = initial_affine(self.seq_length)
+        act_2d = self.pair_layer_norm(pair)
+        # fold iteration
+        atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, act_iter = \
+            self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
+                                     aatype)
+        atom14_pred_positions = vecs_to_tensor(atom_pos)[-1]
+        sidechain_atom_pos = atom_pos
+
+        atom37_pred_positions = atom14_to_atom37(atom14_pred_positions,
+                                                 residx_atom37_to_atom14,
+                                                 atom37_atom_exists,
+                                                 self.indice0)
+
+        structure_traj = affine_output_new * self.traj_w
+        final_affines = affine_output_new[-1]
+        final_atom_positions = atom37_pred_positions
+        final_atom_mask = atom37_atom_exists
+        rp_structure_module = act_iter
+        res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions, final_affines, \
+               angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj)
+        return res
+
+    def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
+                            aatype):
+        """iteration_operation"""
+        affine_init = ()
+        angles_sin_cos_init = ()
+        um_angles_sin_cos_init = ()
+        atom_pos_batch = ()
+        frames_batch = ()
+
+        for _ in range(self.num_layer):
+            act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
+                atom_pos, frames = \
+                self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
+                                    aatype)
+
+            affine_init = affine_init + (affine_output[None, ...],)
+            angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],)
+            um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],)
+            atom_pos_batch += (mnp.concatenate(vecs_expend_dims(atom_pos, 0), axis=0)[:, None, ...],)
+            frames_batch += (mnp.concatenate(rots_expend_dims(frames[0], 0) +
+                                             vecs_expend_dims(frames[1], 0), axis=0)[:, None, ...],)
+        affine_output_new = mnp.concatenate(affine_init, axis=0)
+        angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0)
+        um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0)
+        frames_new = mnp.concatenate(frames_batch, axis=1)
+        atom_pos_new = mnp.concatenate(atom_pos_batch, axis=1)
+        res = (atom_pos_new, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames_new, act)
+        return res
diff --git a/tests/st/mindsponge/test_megafold/module/template_embedding.py b/tests/st/mindsponge/test_megafold/module/template_embedding.py
new file mode 100644
index 000000000..4147965f5
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/module/template_embedding.py
@@ -0,0 +1,262 @@
+# Copyright 2022 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +'''TEMPLATE''' +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore import Parameter +from mindsponge.cell.initializer import lecun_init +from mindsponge.common.utils import dgram_from_positions +from mindsponge.common.geometry import make_transform_from_reference, rot_to_quat, quat_affine, invert_point +from mindsponge.common.residue_constants import atom_order +from mindsponge.cell import Attention, TriangleAttention, Transition, TriangleMultiplication + + +class TemplatePairStack(nn.Cell): + '''template pair stack''' + + def __init__(self, config, mixed_precision): + super(TemplatePairStack, self).__init__() + self.config = config.template.template_pair_stack + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_block = self.config.num_block + batch_size = 0 + self.slice = config.slice.template_pair_stack + start_node_cfg = self.config.triangle_attention_starting_node + self.triangle_attention_starting_node = TriangleAttention(start_node_cfg.orientation, + start_node_cfg.num_head, + start_node_cfg.key_dim, + start_node_cfg.value_dim, + start_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_starting_node, + mixed_precision) + end_node_cfg = self.config.triangle_attention_ending_node + self.triangle_attention_ending_node = TriangleAttention(end_node_cfg.orientation, + end_node_cfg.num_head, + end_node_cfg.key_dim, + end_node_cfg.value_dim, + end_node_cfg.gating, + 64, + batch_size, + self.slice.triangle_attention_ending_node, + mixed_precision) + # Hard Code + self.pair_transition = Transition(self.config.pair_transition.num_intermediate_factor, + 64, + batch_size, + self.slice.pair_transition) + + mul_outgoing_cfg = self.config.triangle_multiplication_outgoing + self.triangle_multiplication_outgoing = TriangleMultiplication(mul_outgoing_cfg.num_intermediate_channel, + mul_outgoing_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size, + mixed_precision=mixed_precision) + mul_incoming_cfg = self.config.triangle_multiplication_incoming + self.triangle_multiplication_incoming = TriangleMultiplication(mul_incoming_cfg.num_intermediate_channel, + mul_incoming_cfg.equation, + layer_norm_dim=64, + batch_size=batch_size, + mixed_precision=mixed_precision) + + def construct(self, pair_act, pair_mask, index=None): + if not self.num_block: + return pair_act + + pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) + pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) + pair_act = pair_act + self.pair_transition(pair_act, index) + return pair_act + + +class SingleTemplateEmbedding(nn.Cell): + '''single template embedding''' + + def __init__(self, config, mixed_precision): + super(SingleTemplateEmbedding, self).__init__() + self.config = config.template + if mixed_precision: + self._type = mstype.float16 + else: + self._type = mstype.float32 + self.num_bins = self.config.dgram_features.num_bins + self.min_bin = 
self.config.dgram_features.min_bin + self.max_bin = self.config.dgram_features.max_bin + + self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + self.embedding2d = nn.Dense(88, self.num_channels, + weight_init=lecun_init(88, initializer_name='relu')) + # if is_training: + template_layers = nn.CellList() + for _ in range(self.config.template_pair_stack.num_block): + template_pair_stack_block = TemplatePairStack(config, mixed_precision) + template_layers.append(template_pair_stack_block) + self.template_pair_stack = template_layers + + self.one_hot = nn.OneHot(depth=22, axis=-1) + self.n, self.ca, self.c = [atom_order[a] for a in ('N', 'CA', 'C')] + + self.use_template_unit_vector = self.config.use_template_unit_vector + layer_norm_dim = 64 + self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) + self.num_block = self.config.template_pair_stack.num_block + self.batch_block = 4 + + def construct(self, mask_2d, template_aatype, template_all_atom_masks, template_all_atom_positions, + template_pseudo_beta_mask, template_pseudo_beta): + '''construct''' + num_res = template_aatype[0, ...].shape[0] + template_mask_2d_temp = P.ExpandDims()(template_pseudo_beta_mask, -1) * \ + P.ExpandDims()(template_pseudo_beta_mask, 1) + template_dgram_temp = dgram_from_positions(template_pseudo_beta, self.num_bins, self.min_bin, + self.max_bin, self._type) + + to_concat_temp = (template_dgram_temp, P.ExpandDims()(template_mask_2d_temp, -1)) + aatype_temp = self.one_hot(template_aatype) + aatype_temp = P.Cast()(aatype_temp, self._type) + to_concat_temp = to_concat_temp + (P.Tile()(P.ExpandDims()(aatype_temp, 1), (1, num_res, 1, 1)), + P.Tile()(P.ExpandDims()(aatype_temp, 2), (1, 1, num_res, 1))) + + rot_temp, trans_temp = make_transform_from_reference(template_all_atom_positions[:, :, self.n], + template_all_atom_positions[:, :, self.ca], + template_all_atom_positions[:, :, self.c]) + + _, rotation_tmp, translation_tmp = quat_affine(rot_to_quat(rot_temp), trans_temp, rot_temp) + points_tmp = [P.ExpandDims()(translation_tmp[0], -2), + P.ExpandDims()(translation_tmp[1], -2), + P.ExpandDims()(translation_tmp[2], -2)] + affine_vec_tmp = invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1) + inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + P.Square()(affine_vec_tmp[0]) + P.Square()(affine_vec_tmp[1]) + \ + P.Square()(affine_vec_tmp[2])) + template_mask_tmp = (template_all_atom_masks[:, :, self.n] * + template_all_atom_masks[:, :, self.ca] * + template_all_atom_masks[:, :, self.c]) + template_mask_2d_tmp = P.ExpandDims()(template_mask_tmp, -1) * P.ExpandDims()(template_mask_tmp, 1) + + inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp + unit_vector_tmp = (P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[0], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[1], -1), + P.ExpandDims()(inv_distance_scalar_tmp * affine_vec_tmp[2], -1)) + + if not self.use_template_unit_vector: + unit_vector_tmp = (P.ZerosLike()(unit_vector_tmp[0]), P.ZerosLike()(unit_vector_tmp[1]), + P.ZerosLike()(unit_vector_tmp[2])) + to_concat_temp = to_concat_temp + unit_vector_tmp + (P.ExpandDims()(template_mask_2d_tmp, -1),) + act_tmp = P.Concat(-1)(to_concat_temp) + + act_tmp = act_tmp * P.ExpandDims()(template_mask_2d_tmp, -1) + act_tmp = self.embedding2d(act_tmp) + + act_tmp = P.Split(0, self.batch_block)(act_tmp) + act = () + for i in range(self.batch_block): + act = act + (P.Squeeze()(act_tmp[i]),) + + output = [] + for i in 
range(self.batch_block):
+            act_batch = act[i]
+            for j in range(self.num_block):
+                act_batch = self.template_pair_stack[j](act_batch, mask_2d)
+            slice_act = P.Reshape()(act_batch, ((1,) + P.Shape()(act_batch)))
+            output.append(slice_act)
+
+        act_tmp_loop = P.Concat()(output)
+        act_tmp = self.output_layer_norm(act_tmp_loop)
+        return act_tmp
+
+
+class TemplateEmbedding(nn.Cell):
+    '''template embedding'''
+
+    def __init__(self, config, seq_len, mixed_precision=True):
+        super(TemplateEmbedding, self).__init__()
+        self.config = config.template
+        if mixed_precision:
+            self._type = mstype.float16
+        else:
+            self._type = mstype.float32
+        self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim)
+        self.template_embedder = SingleTemplateEmbedding(config, mixed_precision)
+        self.template_pointwise_attention = Attention(self.config.attention.num_head,
+                                                      self.config.attention.key_dim,
+                                                      self.config.attention.value_dim,
+                                                      self.config.attention.gating,
+                                                      q_data_dim=128, m_data_dim=64,
+                                                      output_dim=128, batch_size=None,
+                                                      mixed_precision=mixed_precision)
+        self.slice_num = config.slice.template_embedding
+        # Use a local copy so the buffer shapes below are well defined; the
+        # previous `if self.slice_num == 0: slice_num = 1` left the local
+        # variable unbound for every non-zero slice configuration.
+        slice_num = self.slice_num if self.slice_num != 0 else 1
+        self._flat_query_slice = Parameter(
+            Tensor(np.zeros((int(seq_len * seq_len / slice_num), 1, 128)), dtype=mstype.float32), requires_grad=False)
+        self._flat_templates_slice = Parameter(
+            Tensor(np.zeros((int(seq_len * seq_len / slice_num), 4, 64)), dtype=mstype.float32), requires_grad=False)
+
+    def construct(self, query_embedding, template_aatype, template_all_atom_masks, template_all_atom_positions,
+                  template_mask, template_pseudo_beta_mask, template_pseudo_beta, mask_2d):
+        '''construct'''
+        num_templates = template_mask.shape[0]
+        num_channels = self.num_channels
+        num_res = query_embedding.shape[0]
+        query_num_channels = query_embedding.shape[-1]
+        mask_2d = F.depend(mask_2d, query_embedding)
+        template_pair_representation = self.template_embedder(mask_2d, template_aatype,
+                                                              template_all_atom_masks, template_all_atom_positions,
+                                                              template_pseudo_beta_mask,
+                                                              template_pseudo_beta)
+        flat_query = P.Reshape()(query_embedding, (num_res * num_res, 1, query_num_channels))
+        flat_templates = P.Reshape()(
+            P.Transpose()(template_pair_representation, (1, 2, 0, 3)),
+            (num_res * num_res, num_templates, num_channels))
+        template_mask_bias = P.ExpandDims()(P.ExpandDims()(P.ExpandDims()(template_mask, 0), 1), 2) - 1.0
+        bias = 1e4 * template_mask_bias
+        if self.slice_num:
+            slice_shape = (self.slice_num, -1)
+            flat_query_shape = P.Shape()(flat_query)
+            flat_query = P.Reshape()(flat_query, slice_shape + flat_query_shape[1:])
+            flat_templates_shape = P.Shape()(flat_templates)
+            flat_templates = P.Reshape()(flat_templates, slice_shape + flat_templates_shape[1:])
+            slice_idx = 0
+            embedding_tuple = ()
+            while slice_idx < self.slice_num:
+                self._flat_query_slice = flat_query[slice_idx]
+                self._flat_templates_slice = flat_templates[slice_idx]
+                embedding_slice = self.template_pointwise_attention(self._flat_query_slice, self._flat_templates_slice,
+                                                                    bias, index=None, nonbatched_bias=None)
+                embedding_slice = P.Reshape()(embedding_slice, ((1,) + P.Shape()(embedding_slice)))
+                embedding_tuple = embedding_tuple + (embedding_slice,)
+                slice_idx += 1
+            embedding = P.Concat()(embedding_tuple)
+
+            embedding = P.Reshape()(embedding, (num_res, num_res, query_num_channels))
+            # No gradients if no templates.
+            embedding = embedding * (P.ReduceSum()(template_mask) > 0.)
+            return embedding
+        embedding = self.template_pointwise_attention(flat_query, flat_templates, bias, index=None,
+                                                      nonbatched_bias=None)
+        embedding = P.Reshape()(embedding, (num_res, num_res, query_num_channels))
+        # No gradients if no templates.
+        embedding = embedding * (P.ReduceSum()(template_mask) > 0.)
+        return embedding
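The slice_num machinery above bounds peak memory by pushing the num_res * num_res flattened queries through the pointwise attention in fixed-size chunks rather than all at once (the real cell additionally stages each chunk through the pre-allocated _flat_*_slice parameters). The same idea in miniature, as a plain NumPy sketch with a hypothetical fn and sizes:

import numpy as np

def chunked_apply(fn, flat_inputs, num_chunks):
    # flat_inputs: [N, ...]; N must divide evenly by num_chunks, mirroring
    # the seq_len * seq_len / slice_num reshape in TemplateEmbedding.
    chunks = np.split(flat_inputs, num_chunks, axis=0)
    return np.concatenate([fn(chunk) for chunk in chunks], axis=0)

# Toy usage: a 16x16 "pair" grid flattened to 256 queries, processed in 4 chunks.
out = chunked_apply(lambda q: q * 2.0, np.ones((256, 1, 8)), num_chunks=4)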
diff --git a/tests/st/mindsponge/test_megafold/origin_length/T1070-D2.pkl b/tests/st/mindsponge/test_megafold/origin_length/T1070-D2.pkl
new file mode 100644
index 000000000..464911030
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/origin_length/T1070-D2.pkl
@@ -0,0 +1 @@
+€Ke.
\ No newline at end of file
diff --git a/tests/st/mindsponge/test_megafold/processed_feature/T1070-D2.pkl b/tests/st/mindsponge/test_megafold/processed_feature/T1070-D2.pkl
new file mode 100644
index 000000000..464911030
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/processed_feature/T1070-D2.pkl
@@ -0,0 +1 @@
+€Ke.
\ No newline at end of file
diff --git a/tests/st/mindsponge/test_megafold/test_megafold.py b/tests/st/mindsponge/test_megafold/test_megafold.py
new file mode 100644
index 000000000..f20156142
--- /dev/null
+++ b/tests/st/mindsponge/test_megafold/test_megafold.py
@@ -0,0 +1,241 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval script"""
+import pickle
+import time
+import pytest
+import numpy as np
+import mindspore.context as context
+import mindspore.common.dtype as mstype
+from mindspore import Tensor, Parameter
+from mindspore import load_checkpoint, load_param_into_net
+from mindsponge.cell.initializer import do_keep_cell_fp32
+from mindsponge.common.config_load import load_config
+from model.fold import MegaFold, compute_confidence
+
+
+def load_pkl(pickle_path):
+    # use a context manager so the handle is closed even if unpickling fails
+    with open(pickle_path, "rb") as f:
+        data = pickle.load(f)
+    return data
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_megafold_gpu():
+    """
+    Feature: megafold model test in the gpu
+    Description: input the tensors of processed raw feature
+    Expectation: cost_time <= 100, confidence >=30, final_atom_positions same as true label.
+    """
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target="GPU",
+                        max_device_memory="31GB",
+                        memory_optimize_level="O1")
+    context.set_context(enable_graph_kernel=True,
+                        graph_kernel_flags="--enable_expand_ops_only=Softmax --enable_cluster_ops_only=Add")
+    # the test has no CLI wrapper, so load the configs from fixed paths like the other cases
+    data_cfg = load_config('./config/data.yaml')
+    model_cfg = load_config('./config/model.yaml')
+    model_cfg.seq_length = data_cfg.eval.crop_size
+    slice_key = "seq_" + str(model_cfg.seq_length)
+    slice_val = vars(model_cfg.slice)[slice_key]
+    model_cfg.slice = slice_val
+
+    megafold = MegaFold(model_cfg, mixed_precision=False)
+    param_dict = load_checkpoint("/home/workspace/mindspore_ckpt/ckpt/megafold.ckpt")
+
+    new_param_dict = {}
+    for key in param_dict.keys():
+        if 'msa_stack' in key and 'extra' not in key:
+            new_param_dict[key] = Parameter(Tensor(param_dict[key][0:1]), name=key)
+        elif 'template_embedding._flat_templates_slice' in key or 'template_embedding._flat_query_slice' in key or \
+                'template_embedding.template_embedder.idx_num_block' in key or \
+                'template_embedding.template_embedder.idx_batch_loop' in key:
+            continue
+        else:
+            new_param_dict[key] = Parameter(Tensor(param_dict[key]), name=key)
+    load_param_into_net(megafold, new_param_dict)
+
+    megafold.to_float(mstype.float32)
+
+    data = load_pkl('./processed_feature/T1070-D2.pkl')
+    ori_res_length = load_pkl('./origin_length/T1070-D2.pkl')
+    input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype',
+                  'template_all_atom_masks', 'template_all_atom_positions', 'template_mask',
+                  'template_pseudo_beta_mask', 'template_pseudo_beta',
+                  'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask',
+                  'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index']
+    feat = []
+    for key in input_keys:
+        feat.append(Tensor(data[key]))
+    prev_pos = Tensor(np.zeros([data_cfg.eval.crop_size, 37, 3]).astype(np.float32))
+    prev_msa_first_row = Tensor(np.zeros([data_cfg.eval.crop_size, 256]).astype(np.float32))
+    prev_pair = Tensor(np.zeros([data_cfg.eval.crop_size, data_cfg.eval.crop_size, 128]).astype(np.float32))
+    start_time = time.time()
+    for i in range(1):
+        feat_i = [Tensor(x[i]) for x in feat]
+        prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = megafold(*feat_i,
+                                                                                  prev_pos,
+                                                                                  prev_msa_first_row,
+                                                                                  prev_pair)
+    end_time = time.time()
+    cost_time = end_time - start_time
+    final_atom_positions = prev_pos.asnumpy()[:ori_res_length]
+    predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]
+    confidence = compute_confidence(predicted_lddt_logits)
+    true_label = np.load("./true_label/final_atom_positions_gpu_fp32.npy", allow_pickle=True)
+    res = np.allclose(final_atom_positions, true_label, rtol=1e-05, atol=1.e-5)
+    print("cost time : ", cost_time, "s")
+    print(confidence)
+    print(sum(final_atom_positions - true_label))
+    print(res)
+    assert res
+    assert cost_time <= 100
+    assert confidence >= 30
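For context, the confidence asserted above is a scalar reduction of predicted_lddt_logits. The sketch below shows the usual pLDDT reduction in NumPy, i.e. an assumption about what compute_confidence does rather than a copy of it: softmax the per-residue bin logits, take the expectation over uniform bin centers in [0, 1], scale to 0-100 and average.

import numpy as np

def plddt_confidence(predicted_lddt_logits):
    # logits: [N_res, num_bins]; bins partition lDDT in [0, 1] uniformly
    num_bins = predicted_lddt_logits.shape[-1]
    bin_width = 1.0 / num_bins
    bin_centers = np.arange(0.5 * bin_width, 1.0, bin_width)
    probs = np.exp(predicted_lddt_logits - predicted_lddt_logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    per_residue = 100.0 * (probs * bin_centers).sum(-1)   # per-residue pLDDT, 0-100
    return per_residue.mean()                             # scalar confidence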
+ """ + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + max_device_memory="31GB") + data_cfg = load_config('./config/data.yaml') + model_cfg = load_config('./config/model.yaml') + model_cfg.seq_length = data_cfg.eval.crop_size + slice_key = "seq_" + str(model_cfg.seq_length) + slice_val = vars(model_cfg.slice)[slice_key] + model_cfg.slice = slice_val + megafold = MegaFold(model_cfg, mixed_precision=True) + param_dict = load_checkpoint("/home/workspace/mindspore_ckpt/ckpt/megafold.ckpt") + + new_param_dict = {} + for key in param_dict.keys(): + if 'msa_stack' in key and 'extra' not in key: + new_param_dict[key] = Parameter(Tensor(param_dict[key][0:1]), name=key) + elif 'template_embedding._flat_templates_slice' in key or 'template_embedding._flat_query_slice' in key or \ + 'template_embedding.template_embedder.idx_num_block' in key or \ + 'template_embedding.template_embedder.idx_batch_loop' in key: + continue + else: + new_param_dict[key] = Parameter(Tensor(param_dict[key]), name=key) + load_param_into_net(megafold, new_param_dict) + + megafold.to_float(mstype.float16) + do_keep_cell_fp32(megafold) + + data = load_pkl('./processed_feature/T1070-D2.pkl') + ori_res_length = load_pkl('./origin_length/T1070-D2.pkl') + input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype', + 'template_all_atom_masks', 'template_all_atom_positions', 'template_mask', + 'template_pseudo_beta_mask', 'template_pseudo_beta', + 'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index'] + feat = [] + for key in input_keys: + feat.append(Tensor(data[key])) + prev_pos = Tensor(np.zeros([data_cfg.eval.crop_size, 37, 3]).astype(np.float16)) + prev_msa_first_row = Tensor(np.zeros([data_cfg.eval.crop_size, 256]).astype(np.float16)) + prev_pair = Tensor(np.zeros([data_cfg.eval.crop_size, data_cfg.eval.crop_size, 128]).astype(np.float16)) + + start_time = time.time() + for i in range(1): + feat_i = [Tensor(x[i]) for x in feat] + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = megafold(*feat_i, + prev_pos, + prev_msa_first_row, + prev_pair) + end_time = time.time() + cost_time = end_time-start_time + final_atom_positions = prev_pos.asnumpy()[:ori_res_length] + predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length] + confidence = compute_confidence(predicted_lddt_logits) + true_label = np.load("./true_label/final_atom_positions_ascend_mixed_precision.npy", allow_pickle=True) + res = np.allclose(final_atom_positions, true_label, rtol=1e-05, atol=1.e-5) + print("cost time: ", cost_time, "s") + print(confidence) + print(sum(final_atom_positions-true_label)) + print(res) + assert res + assert cost_time <= 200 + assert confidence >= 30 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_ascend_training_true(): + """ + Feature: megafold model test in the ascend when the training flag is true + Description: input test + Expectation: success or throw xxx exception or result == xxx, etc. 
+ """ + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + max_device_memory="31GB") + data_cfg = load_config('./config/data.yaml') + model_cfg = load_config('./config/model.yaml') + model_cfg.is_training = True + model_cfg.seq_length = data_cfg.eval.crop_size + slice_key = "seq_" + str(model_cfg.seq_length) + slice_val = vars(model_cfg.slice)[slice_key] + model_cfg.slice = slice_val + megafold = MegaFold(model_cfg, mixed_precision=True) + load_checkpoint("/home/workspace/mindspore_ckpt/ckpt/megafold_training.ckpt", megafold) + + megafold.to_float(mstype.float16) + do_keep_cell_fp32(megafold) + + data = load_pkl('./processed_feature/T1070-D2.pkl') + ori_res_length = load_pkl('./origin_length/T1070-D2.pkl') + input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype', + 'template_all_atom_masks', 'template_all_atom_positions', 'template_mask', + 'template_pseudo_beta_mask', 'template_pseudo_beta', + 'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask', + 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index'] + feat = [] + for key in input_keys: + feat.append(Tensor(data[key])) + prev_pos = Tensor(np.zeros([data_cfg.eval.crop_size, 37, 3]).astype(np.float16)) + prev_msa_first_row = Tensor(np.zeros([data_cfg.eval.crop_size, 256]).astype(np.float16)) + prev_pair = Tensor(np.zeros([data_cfg.eval.crop_size, data_cfg.eval.crop_size, 128]).astype(np.float16)) + + start_time = time.time() + for i in range(1): + feat_i = [Tensor(x[i]) for x in feat] + prev_pos, prev_msa_first_row, prev_pair = megafold(*feat_i, + prev_pos, + prev_msa_first_row, + prev_pair) + end_time = time.time() + cost_time = end_time - start_time + final_atom_positions = prev_pos.asnumpy()[:ori_res_length] + true_label = np.load("./true_label/final_atom_positions_mixed_precision.npy", allow_pickle=True) + res = np.allclose(final_atom_positions, true_label, rtol=1e-05, atol=1.e-5) + print("cost time: ", cost_time, "s") + print(sum(final_atom_positions-true_label)) + print(res) + assert res + assert cost_time <= 200 diff --git a/tests/st/mindsponge/test_megafold/true_label/final_atom_positions_ascend_mixed_precision.npy b/tests/st/mindsponge/test_megafold/true_label/final_atom_positions_ascend_mixed_precision.npy new file mode 100644 index 0000000000000000000000000000000000000000..f93e22d40436de324f6538c1ab179d21327e5813 GIT binary patch literal 22550 zcmdU%33ye-6~`q7L|oWGt%a~?009a69%M5cnYm#Cfe;kDvKz>VJR-uYFj^QaqFVk_Lud0kC*%2&4VP5d>`+-%$+;;oH;XR`JXd) zQA}KK-M`E&6K|OsGAw!A(2S4{!6BW7w+Lwx95OsTV|+%^V}sK(h9zgW_ex3|mrVP( z)TA-V{J&vn^Ux;2E!*<1QIp{5!G8QzJfxPG1i3}8HbbSMUT#LqoqC6{q@{k^q{uBf zfA6buX={7ewP#6Q|4aSeWJrCz%?y_|`c*SfeyI-`BX#rw z(_gCUizd+r&6V_Jc)xy{Yw4i4K0d~|O;?R;A@}K!IF;A4x~?Agi0xB%>7AyBoKu@k zJK3+^G7aP#^;^?IepH9eFXbyWm%C0VudHXAI9a2n82fj5j zXoeK>#xNCc`WDijJLYVa$Wsrhp{BcApW|wSnI>1&UNcde>D^|se4!4R!O~f;HIrl^ z>ur!*e@_bMsZULWk3ODHIce#wE;Co}sEY->J9)~rPXp^K(@^fPelow3`c{>=I#S;H zrzt0qR@J!LZZvD=qk;o`o#_QKjABojZ!O|oVI!cYiu;vy=~v_IzNGYeNpX(m(yyK z=_HfXdbs7IDgystRaG~wRN8$C1+=Z1nKw|VJ_!^`4;&nWVKh{<!rg
[binary patch data omitted]

diff --git a/tests/st/mindsponge/test_megafold/true_label/final_atom_positions_gpu_fp32.npy b/tests/st/mindsponge/test_megafold/true_label/final_atom_positions_gpu_fp32.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3ab958873c17c0fbcfe971935521999557aa37cb
GIT binary patch
literal 44972
[binary patch data omitted]
-- Gitee


From 1e462dcdbf525dd50880b654c19855ff088c8e03 Mon Sep 17 00:00:00 2001
From: l00500167
Date: Wed, 27 Jul 2022 13:01:14 +0800
Subject: [PATCH 2/4] debug

Signed-off-by: l00500167
---
 tests/st/mindsponge/test_megafold/test_megafold.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/st/mindsponge/test_megafold/test_megafold.py b/tests/st/mindsponge/test_megafold/test_megafold.py
index f20156142..e341530a3 100644
--- a/tests/st/mindsponge/test_megafold/test_megafold.py
+++ b/tests/st/mindsponge/test_megafold/test_megafold.py
@@ -44,8 +44,7 @@ def test_megafold_gpu():
     """
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="GPU",
-                        max_device_memory="31GB",
-                        memory_optimize_level="O1")
+                        max_device_memory="31GB")
     context.set_context(enable_graph_kernel=True,
                         graph_kernel_flags="--enable_expand_ops_only=Softmax --enable_cluster_ops_only=Add")
     data_cfg = load_config(args.data_config)
-- Gitee


From a51a9a1458a2dddc7ecc1f00347b1894162f707a Mon Sep 17 00:00:00 2001
From: l00500167
Date: Wed, 27 Jul 2022 15:23:15 +0800
Subject: [PATCH 3/4] test

Signed-off-by: l00500167
---
 tests/st/mindsponge/test_megafold/test_megafold.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/st/mindsponge/test_megafold/test_megafold.py b/tests/st/mindsponge/test_megafold/test_megafold.py
index e341530a3..99f257fb7 100644
--- a/tests/st/mindsponge/test_megafold/test_megafold.py
+++ b/tests/st/mindsponge/test_megafold/test_megafold.py
@@ -47,8 +47,8 @@ def test_megafold_gpu():
                         max_device_memory="31GB")
     context.set_context(enable_graph_kernel=True,
                         graph_kernel_flags="--enable_expand_ops_only=Softmax --enable_cluster_ops_only=Add")
-    data_cfg = load_config(args.data_config)
-    model_cfg = load_config(args.model_config)
+    data_cfg = load_config('./config/data.yaml')
+    model_cfg = load_config('./config/model.yaml')
     model_cfg.seq_length = data_cfg.eval.crop_size
     slice_key = "seq_" + str(model_cfg.seq_length)
     slice_val = vars(model_cfg.slice)[slice_key]
-- Gitee
From 0e0eafc7f6356dfc308e08507ea7bcacbc06438b Mon Sep 17 00:00:00 2001
From: l00500167
Date: Wed, 27 Jul 2022 20:15:20 +0800
Subject: [PATCH 4/4] add

Signed-off-by: l00500167
---
 tests/st/mindsponge/test_megafold/test_megafold.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/st/mindsponge/test_megafold/test_megafold.py b/tests/st/mindsponge/test_megafold/test_megafold.py
index 99f257fb7..7b36b56bc 100644
--- a/tests/st/mindsponge/test_megafold/test_megafold.py
+++ b/tests/st/mindsponge/test_megafold/test_megafold.py
@@ -71,7 +71,7 @@ def test_megafold_gpu():
 
     megafold.to_float(mstype.float32)
 
-    data = load_pkl('./T1070-D2.pkl')
+    data = load_pkl('./processed_feature/T1070-D2.pkl')
     ori_res_length = load_pkl('./origin_length/T1070-D2.pkl')
     input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype',
                   'template_all_atom_masks', 'template_all_atom_positions', 'template_mask',
-- Gitee